From 47b7ebad12a17218f6ca0301fc802c0e0a81d873 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 28 Jul 2012 20:03:26 -0700 Subject: Added the Spark Streaing code, ported to Akka 2 --- core/src/main/scala/spark/BlockRDD.scala | 42 ++++++++++++++++++++++++++++ core/src/main/scala/spark/SparkContext.scala | 5 ++++ 2 files changed, 47 insertions(+) create mode 100644 core/src/main/scala/spark/BlockRDD.scala (limited to 'core') diff --git a/core/src/main/scala/spark/BlockRDD.scala b/core/src/main/scala/spark/BlockRDD.scala new file mode 100644 index 0000000000..ea009f0f4f --- /dev/null +++ b/core/src/main/scala/spark/BlockRDD.scala @@ -0,0 +1,42 @@ +package spark + +import scala.collection.mutable.HashMap + +class BlockRDDSplit(val blockId: String, idx: Int) extends Split { + val index = idx +} + + +class BlockRDD[T: ClassManifest](sc: SparkContext, blockIds: Array[String]) extends RDD[T](sc) { + + @transient + val splits_ = (0 until blockIds.size).map(i => { + new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] + }).toArray + + @transient + lazy val locations_ = { + val blockManager = SparkEnv.get.blockManager + /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ + val locations = blockManager.getLocations(blockIds) + HashMap(blockIds.zip(locations):_*) + } + + override def splits = splits_ + + override def compute(split: Split): Iterator[T] = { + val blockManager = SparkEnv.get.blockManager + val blockId = split.asInstanceOf[BlockRDDSplit].blockId + blockManager.get(blockId) match { + case Some(block) => block.asInstanceOf[Iterator[T]] + case None => + throw new Exception("Could not compute split, block " + blockId + " not found") + } + } + + override def preferredLocations(split: Split) = + locations_(split.asInstanceOf[BlockRDDSplit].blockId) + + override val dependencies: List[Dependency[_]] = Nil +} + diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index dd17d4d6b3..78c7618542 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -409,6 +409,11 @@ class SparkContext( * various Spark features. */ object SparkContext { + + // TODO: temporary hack for using HDFS as input in streaing + var inputFile: String = null + var idealPartitions: Int = 1 + implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double) = 0.0 -- cgit v1.2.3 From 3be54c2a8afcb2a3abf1cf22934123fae3419278 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 1 Aug 2012 22:09:27 -0700 Subject: 1. Refactored SparkStreamContext, Scheduler, InputRDS, FileInputRDS and a few other files. 2. Modified Time class to represent milliseconds (long) directly, instead of LongTime. 3. Added new files QueueInputRDS, RecurringTimer, etc. 4. Added RDDSuite as the skeleton for testcases. 5. Added two examples in spark.streaming.examples. 6. Removed all past examples and a few unnecessary files. Moved a number of files to spark.streaming.util. 
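[Editor's note: a minimal, hypothetical sketch of how the refactored streaming API introduced in this commit could be driven end to end. The class and method names used here (SparkStreamContext, setBatchDuration, createQueueStream, reduceByKey, print, start, stop) are taken from the patch below; the object name, the wiring, and the timings are illustrative assumptions, not code from the commit.]

import scala.collection.mutable.Queue
import spark.RDD
import spark.streaming.SparkStreamContext
import spark.streaming.SparkStreamContext._  // implicit conversion to PairRDSFunctions

// Hypothetical driver object, not part of the patch.
object QueueStreamSketch {
  def main(args: Array[String]) {
    val ssc = new SparkStreamContext("local", "QueueStreamSketch")
    ssc.setBatchDuration(1000)  // 1-second batches, using the Long overload added in this patch

    // RDDs pushed into this queue are consumed one per batch (oneAtATime defaults to true).
    val queue = new Queue[RDD[(String, Int)]]()
    val input = ssc.createQueueStream(queue)

    // reduceByKey comes from PairRDSFunctions via the implicit imported above.
    val counts = input.reduceByKey(_ + _)
    counts.print()  // assumed to register an output RDS; verify() requires at least one

    ssc.start()
    queue += ssc.sc.parallelize(Seq(("a", 1), ("b", 1), ("a", 1)))
    Thread.sleep(3000)  // let a few batches run
    ssc.stop()
  }
}

The queue-backed input added by this commit (QueueInputRDS) makes it possible to exercise the scheduler and the RDS transformations without any network or file source, which is presumably the intent behind pairing it with the new RDSSuite test skeleton.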
--- core/src/main/scala/spark/Utils.scala | 2 +- .../src/main/scala/spark/streaming/BlockID.scala | 20 -- .../main/scala/spark/streaming/FileInputRDS.scala | 163 +++++++++++ .../scala/spark/streaming/FileStreamReceiver.scala | 70 ----- .../scala/spark/streaming/IdealPerformance.scala | 36 --- .../src/main/scala/spark/streaming/Interval.scala | 6 +- streaming/src/main/scala/spark/streaming/Job.scala | 9 +- .../main/scala/spark/streaming/JobManager.scala | 23 +- .../spark/streaming/NetworkStreamReceiver.scala | 184 ------------- .../scala/spark/streaming/PairRDSFunctions.scala | 72 +++++ .../main/scala/spark/streaming/QueueInputRDS.scala | 36 +++ streaming/src/main/scala/spark/streaming/RDS.scala | 305 ++++++--------------- .../src/main/scala/spark/streaming/Scheduler.scala | 171 ++---------- .../scala/spark/streaming/SparkStreamContext.scala | 150 +++++++--- .../spark/streaming/TestInputBlockTracker.scala | 42 --- .../spark/streaming/TestStreamReceiver3.scala | 18 +- .../spark/streaming/TestStreamReceiver4.scala | 16 +- .../src/main/scala/spark/streaming/Time.scala | 100 ++++--- .../examples/DumbTopKWordCount2_Special.scala | 138 ---------- .../examples/DumbWordCount2_Special.scala | 92 ------- .../spark/streaming/examples/ExampleOne.scala | 41 +++ .../spark/streaming/examples/ExampleTwo.scala | 47 ++++ .../scala/spark/streaming/examples/GrepCount.scala | 39 --- .../spark/streaming/examples/GrepCount2.scala | 113 -------- .../spark/streaming/examples/GrepCountApprox.scala | 54 ---- .../spark/streaming/examples/SimpleWordCount.scala | 30 -- .../streaming/examples/SimpleWordCount2.scala | 51 ---- .../examples/SimpleWordCount2_Special.scala | 83 ------ .../spark/streaming/examples/TopContentCount.scala | 97 ------- .../spark/streaming/examples/TopKWordCount2.scala | 103 ------- .../examples/TopKWordCount2_Special.scala | 142 ---------- .../scala/spark/streaming/examples/WordCount.scala | 62 ----- .../spark/streaming/examples/WordCount1.scala | 46 ---- .../spark/streaming/examples/WordCount2.scala | 55 ---- .../streaming/examples/WordCount2_Special.scala | 94 ------- .../spark/streaming/examples/WordCount3.scala | 49 ---- .../spark/streaming/examples/WordCountEc2.scala | 41 --- .../examples/WordCountTrivialWindow.scala | 51 ---- .../scala/spark/streaming/examples/WordMax.scala | 64 ----- .../spark/streaming/util/RecurringTimer.scala | 52 ++++ .../spark/streaming/util/SenderReceiverTest.scala | 64 +++++ .../streaming/util/SentenceFileGenerator.scala | 92 +++++++ .../scala/spark/streaming/util/ShuffleTest.scala | 23 ++ .../main/scala/spark/streaming/util/Utils.scala | 9 + .../utils/SenGeneratorForPerformanceTest.scala | 78 ------ .../spark/streaming/utils/SenderReceiverTest.scala | 63 ----- .../streaming/utils/SentenceFileGenerator.scala | 92 ------- .../spark/streaming/utils/SentenceGenerator.scala | 103 ------- .../scala/spark/streaming/utils/ShuffleTest.scala | 22 -- .../src/test/scala/spark/streaming/RDSSuite.scala | 65 +++++ 50 files changed, 971 insertions(+), 2607 deletions(-) delete mode 100644 streaming/src/main/scala/spark/streaming/BlockID.scala create mode 100644 streaming/src/main/scala/spark/streaming/FileInputRDS.scala delete mode 100644 streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala delete mode 100644 streaming/src/main/scala/spark/streaming/IdealPerformance.scala delete mode 100644 streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala create mode 100644 streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala create mode 100644 
streaming/src/main/scala/spark/streaming/QueueInputRDS.scala delete mode 100644 streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount1.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount3.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordMax.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/Utils.scala delete mode 100644 streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala delete mode 100644 streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala delete mode 100644 streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala delete mode 100644 streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala delete mode 100644 streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala create mode 100644 streaming/src/test/scala/spark/streaming/RDSSuite.scala (limited to 'core') diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 5eda1011f9..1d33f7d6b3 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -185,7 +185,7 @@ object Utils { * millisecond. 
*/ def getUsedTimeMs(startTimeMs: Long): String = { - return " " + (System.currentTimeMillis - startTimeMs) + " ms " + return " " + (System.currentTimeMillis - startTimeMs) + " ms" } /** diff --git a/streaming/src/main/scala/spark/streaming/BlockID.scala b/streaming/src/main/scala/spark/streaming/BlockID.scala deleted file mode 100644 index 16aacfda18..0000000000 --- a/streaming/src/main/scala/spark/streaming/BlockID.scala +++ /dev/null @@ -1,20 +0,0 @@ -package spark.streaming - -case class BlockID(sRds: String, sInterval: Interval, sPartition: Int) { - override def toString : String = ( - sRds + BlockID.sConnector + - sInterval.beginTime + BlockID.sConnector + - sInterval.endTime + BlockID.sConnector + - sPartition - ) -} - -object BlockID { - val sConnector = '-' - - def parse(name : String) = BlockID( - name.split(BlockID.sConnector)(0), - new Interval(name.split(BlockID.sConnector)(1).toLong, - name.split(BlockID.sConnector)(2).toLong), - name.split(BlockID.sConnector)(3).toInt) -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/FileInputRDS.scala b/streaming/src/main/scala/spark/streaming/FileInputRDS.scala new file mode 100644 index 0000000000..dde80cd27a --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/FileInputRDS.scala @@ -0,0 +1,163 @@ +package spark.streaming + +import spark.SparkContext +import spark.RDD +import spark.BlockRDD +import spark.UnionRDD +import spark.storage.StorageLevel +import spark.streaming._ + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import java.net.InetSocketAddress + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.PathFilter +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} + + +class FileInputRDS[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( + ssc: SparkStreamContext, + directory: Path, + filter: PathFilter = FileInputRDS.defaultPathFilter, + newFilesOnly: Boolean = true) + extends InputRDS[(K, V)](ssc) { + + val fs = directory.getFileSystem(new Configuration()) + var lastModTime: Long = 0 + + override def start() { + if (newFilesOnly) { + lastModTime = System.currentTimeMillis() + } else { + lastModTime = 0 + } + } + + override def stop() { } + + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + val newFilter = new PathFilter() { + var latestModTime = 0L + + def accept(path: Path): Boolean = { + + if (!filter.accept(path)) { + return false + } else { + val modTime = fs.getFileStatus(path).getModificationTime() + if (modTime < lastModTime) { + return false + } + if (modTime > latestModTime) { + latestModTime = modTime + } + return true + } + } + } + + val newFiles = fs.listStatus(directory, newFilter) + lastModTime = newFilter.latestModTime + val newRDD = new UnionRDD(ssc.sc, newFiles.map(file => + ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString)) + ) + Some(newRDD) + } +} + +object FileInputRDS { + val defaultPathFilter = new PathFilter { + def accept(path: Path): Boolean = { + val file = path.getName() + if (file.startsWith(".") || file.endsWith("_tmp")) { + return false + } else { + return true + } + } + } +} + +/* +class NetworkInputRDS[T: ClassManifest]( + val networkInputName: String, + val addresses: Array[InetSocketAddress], + batchDuration: Time, + ssc: SparkStreamContext) +extends InputRDS[T](networkInputName, batchDuration, ssc) { + + + // TODO(Haoyuan): This is for the performance test. 
+ @transient var rdd: RDD[T] = null + + if (System.getProperty("spark.fake", "false") == "true") { + logInfo("Running initial count to cache fake RDD") + rdd = ssc.sc.textFile(SparkContext.inputFile, + SparkContext.idealPartitions).asInstanceOf[RDD[T]] + val fakeCacheLevel = System.getProperty("spark.fake.cache", "") + if (fakeCacheLevel == "MEMORY_ONLY_2") { + rdd.persist(StorageLevel.MEMORY_ONLY_2) + } else if (fakeCacheLevel == "MEMORY_ONLY_DESER_2") { + rdd.persist(StorageLevel.MEMORY_ONLY_2) + } else if (fakeCacheLevel != "") { + logError("Invalid fake cache level: " + fakeCacheLevel) + System.exit(1) + } + rdd.count() + } + + @transient val references = new HashMap[Time,String] + + override def compute(validTime: Time): Option[RDD[T]] = { + if (System.getProperty("spark.fake", "false") == "true") { + logInfo("Returning fake RDD at " + validTime) + return Some(rdd) + } + references.get(validTime) match { + case Some(reference) => + if (reference.startsWith("file") || reference.startsWith("hdfs")) { + logInfo("Reading from file " + reference + " for time " + validTime) + Some(ssc.sc.textFile(reference).asInstanceOf[RDD[T]]) + } else { + logInfo("Getting from BlockManager " + reference + " for time " + validTime) + Some(new BlockRDD(ssc.sc, Array(reference))) + } + case None => + throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") + None + } + } + + def setReference(time: Time, reference: AnyRef) { + references += ((time, reference.toString)) + logInfo("Reference added for time " + time + " - " + reference.toString) + } +} + + +class TestInputRDS( + val testInputName: String, + batchDuration: Time, + ssc: SparkStreamContext) +extends InputRDS[String](testInputName, batchDuration, ssc) { + + @transient val references = new HashMap[Time,Array[String]] + + override def compute(validTime: Time): Option[RDD[String]] = { + references.get(validTime) match { + case Some(reference) => + Some(new BlockRDD[String](ssc.sc, reference)) + case None => + throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") + None + } + } + + def setReference(time: Time, reference: AnyRef) { + references += ((time, reference.asInstanceOf[Array[String]])) + } +} +*/ diff --git a/streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala b/streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala deleted file mode 100644 index 92c7cfe00c..0000000000 --- a/streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala +++ /dev/null @@ -1,70 +0,0 @@ -package spark.streaming - -import spark.Logging - -import scala.collection.mutable.HashSet -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -class FileStreamReceiver ( - inputName: String, - rootDirectory: String, - intervalDuration: Long) - extends Logging { - - val pollInterval = 100 - val sparkstreamScheduler = { - val host = System.getProperty("spark.master.host") - val port = System.getProperty("spark.master.port").toInt + 1 - RemoteActor.select(Node(host, port), 'SparkStreamScheduler) - } - val directory = new Path(rootDirectory) - val fs = directory.getFileSystem(new Configuration()) - val files = new HashSet[String]() - var time: Long = 0 - - def start() { - fs.mkdirs(directory) - files ++= getFiles() - - actor { - logInfo("Monitoring 
directory - " + rootDirectory) - while(true) { - testFiles(getFiles()) - Thread.sleep(pollInterval) - } - } - } - - def getFiles(): Iterable[String] = { - fs.listStatus(directory).map(_.getPath.toString) - } - - def testFiles(fileList: Iterable[String]) { - fileList.foreach(file => { - if (!files.contains(file)) { - if (!file.endsWith("_tmp")) { - notifyFile(file) - } - files += file - } - }) - } - - def notifyFile(file: String) { - logInfo("Notifying file " + file) - time += intervalDuration - val interval = Interval(LongTime(time), LongTime(time + intervalDuration)) - sparkstreamScheduler ! InputGenerated(inputName, interval, file) - } -} - - diff --git a/streaming/src/main/scala/spark/streaming/IdealPerformance.scala b/streaming/src/main/scala/spark/streaming/IdealPerformance.scala deleted file mode 100644 index 303d4e7ae6..0000000000 --- a/streaming/src/main/scala/spark/streaming/IdealPerformance.scala +++ /dev/null @@ -1,36 +0,0 @@ -package spark.streaming - -import scala.collection.mutable.Map - -object IdealPerformance { - val base: String = "The medium researcher counts around the pinched troop The empire breaks " + - "Matei Matei announces HY with a theorem " - - def main (args: Array[String]) { - val sentences: String = base * 100000 - - for (i <- 1 to 30) { - val start = System.nanoTime - - val words = sentences.split(" ") - - val pairs = words.map(word => (word, 1)) - - val counts = Map[String, Int]() - - println("Job " + i + " position A at " + (System.nanoTime - start) / 1e9) - - pairs.foreach((pair) => { - var t = counts.getOrElse(pair._1, 0) - counts(pair._1) = t + pair._2 - }) - println("Job " + i + " position B at " + (System.nanoTime - start) / 1e9) - - for ((word, count) <- counts) { - print(word + " " + count + "; ") - } - println - println("Job " + i + " finished in " + (System.nanoTime - start) / 1e9) - } - } -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala index 9a61d85274..1960097216 100644 --- a/streaming/src/main/scala/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/spark/streaming/Interval.scala @@ -2,7 +2,7 @@ package spark.streaming case class Interval (val beginTime: Time, val endTime: Time) { - def this(beginMs: Long, endMs: Long) = this(new LongTime(beginMs), new LongTime(endMs)) + def this(beginMs: Long, endMs: Long) = this(Time(beginMs), new Time(endMs)) def duration(): Time = endTime - beginTime @@ -44,8 +44,8 @@ object Interval { def zero() = new Interval (Time.zero, Time.zero) - def currentInterval(intervalDuration: LongTime): Interval = { - val time = LongTime(System.currentTimeMillis) + def currentInterval(intervalDuration: Time): Interval = { + val time = Time(System.currentTimeMillis) val intervalBegin = time.floor(intervalDuration) Interval(intervalBegin, intervalBegin + intervalDuration) } diff --git a/streaming/src/main/scala/spark/streaming/Job.scala b/streaming/src/main/scala/spark/streaming/Job.scala index f7654dff79..36958dafe1 100644 --- a/streaming/src/main/scala/spark/streaming/Job.scala +++ b/streaming/src/main/scala/spark/streaming/Job.scala @@ -1,13 +1,14 @@ package spark.streaming +import spark.streaming.util.Utils + class Job(val time: Time, func: () => _) { val id = Job.getNewId() - - def run() { - func() + def run(): Long = { + Utils.time { func() } } - override def toString = "SparkStream Job " + id + ":" + time + override def toString = "streaming job " + id + " @ " + time } object Job { diff --git 
a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index d7d88a7000..43d167f7db 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -1,6 +1,7 @@ package spark.streaming -import spark.{Logging, SparkEnv} +import spark.Logging +import spark.SparkEnv import java.util.concurrent.Executors @@ -10,19 +11,14 @@ class JobManager(ssc: SparkStreamContext, numThreads: Int = 1) extends Logging { def run() { SparkEnv.set(ssc.env) try { - logInfo("Starting " + job) - job.run() - logInfo("Finished " + job) - if (job.time.isInstanceOf[LongTime]) { - val longTime = job.time.asInstanceOf[LongTime] - logInfo("Total notification + skew + processing delay for " + longTime + " is " + - (System.currentTimeMillis - longTime.milliseconds) / 1000.0 + " s") - if (System.getProperty("spark.stream.distributed", "false") == "true") { - TestInputBlockTracker.setEndTime(job.time) - } - } + val timeTaken = job.run() + logInfo( + "Runnning " + job + " took " + timeTaken + " ms, " + + "total delay was " + (System.currentTimeMillis - job.time) + " ms" + ) } catch { - case e: Exception => logError("SparkStream job failed", e) + case e: Exception => + logError("Running " + job + " failed", e) } } } @@ -33,5 +29,6 @@ class JobManager(ssc: SparkStreamContext, numThreads: Int = 1) extends Logging { def runJob(job: Job) { jobExecutor.execute(new JobHandler(ssc, job)) + logInfo("Added " + job + " to queue") } } diff --git a/streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala b/streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala deleted file mode 100644 index efd4689cf0..0000000000 --- a/streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala +++ /dev/null @@ -1,184 +0,0 @@ -package spark.streaming - -import spark.Logging -import spark.storage.StorageLevel - -import scala.math._ -import scala.collection.mutable.{Queue, HashMap, ArrayBuffer} -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.io.BufferedWriter -import java.io.OutputStreamWriter - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -/*import akka.actor.Actor._*/ - -class NetworkStreamReceiver[T: ClassManifest] ( - inputName: String, - intervalDuration: Time, - splitId: Int, - ssc: SparkStreamContext, - tempDirectory: String) - extends DaemonActor - with Logging { - - /** - * Assume all data coming in has non-decreasing timestamp. 
- */ - final class Inbox[T: ClassManifest] (intervalDuration: Time) { - var currentBucket: (Interval, ArrayBuffer[T]) = null - val filledBuckets = new Queue[(Interval, ArrayBuffer[T])]() - - def += (tuple: (Time, T)) = addTuple(tuple) - - def addTuple(tuple: (Time, T)) { - val (time, data) = tuple - val interval = getInterval (time) - - filledBuckets.synchronized { - if (currentBucket == null) { - currentBucket = (interval, new ArrayBuffer[T]()) - } - - if (interval != currentBucket._1) { - filledBuckets += currentBucket - currentBucket = (interval, new ArrayBuffer[T]()) - } - - currentBucket._2 += data - } - } - - def getInterval(time: Time): Interval = { - val intervalBegin = time.floor(intervalDuration) - Interval (intervalBegin, intervalBegin + intervalDuration) - } - - def hasFilledBuckets(): Boolean = { - filledBuckets.synchronized { - return filledBuckets.size > 0 - } - } - - def popFilledBucket(): (Interval, ArrayBuffer[T]) = { - filledBuckets.synchronized { - if (filledBuckets.size == 0) { - return null - } - return filledBuckets.dequeue() - } - } - } - - val inbox = new Inbox[T](intervalDuration) - lazy val sparkstreamScheduler = { - val host = System.getProperty("spark.master.host") - val port = System.getProperty("spark.master.port").toInt - val url = "akka://spark@%s:%s/user/SparkStreamScheduler".format(host, port) - ssc.actorSystem.actorFor(url) - } - /*sparkstreamScheduler ! Test()*/ - - val intervalDurationMillis = intervalDuration.asInstanceOf[LongTime].milliseconds - val useBlockManager = true - - initLogging() - - override def act() { - // register the InputReceiver - val port = 7078 - RemoteActor.alive(port) - RemoteActor.register(Symbol("NetworkStreamReceiver-"+inputName), self) - logInfo("Registered actor on port " + port) - - loop { - reactWithin (getSleepTime) { - case TIMEOUT => - flushInbox() - case data => - val t = data.asInstanceOf[T] - inbox += (getTimeFromData(t), t) - } - } - } - - def getSleepTime(): Long = { - (System.currentTimeMillis / intervalDurationMillis + 1) * - intervalDurationMillis - System.currentTimeMillis - } - - def getTimeFromData(data: T): Time = { - LongTime(System.currentTimeMillis) - } - - def flushInbox() { - while (inbox.hasFilledBuckets) { - inbox.synchronized { - val (interval, data) = inbox.popFilledBucket() - val dataArray = data.toArray - logInfo("Received " + dataArray.length + " items at interval " + interval) - val reference = { - if (useBlockManager) { - writeToBlockManager(dataArray, interval) - } else { - writeToDisk(dataArray, interval) - } - } - if (reference != null) { - logInfo("Notifying scheduler") - sparkstreamScheduler ! InputGenerated(inputName, interval, reference.toString) - } - } - } - } - - def writeToDisk(data: Array[T], interval: Interval): String = { - try { - // TODO(Haoyuan): For current test, the following writing to file lines could be - // commented. 
- val fs = new Path(tempDirectory).getFileSystem(new Configuration()) - val inputDir = new Path( - tempDirectory, - inputName + "-" + interval.toFormattedString) - val inputFile = new Path(inputDir, "part-" + splitId) - logInfo("Writing to file " + inputFile) - if (System.getProperty("spark.fake", "false") != "true") { - val writer = new BufferedWriter(new OutputStreamWriter(fs.create(inputFile, true))) - data.foreach(x => writer.write(x.toString + "\n")) - writer.close() - } else { - logInfo("Fake file") - } - inputFile.toString - }catch { - case e: Exception => - logError("Exception writing to file at interval " + interval + ": " + e.getMessage, e) - null - } - } - - def writeToBlockManager(data: Array[T], interval: Interval): String = { - try{ - val blockId = inputName + "-" + interval.toFormattedString + "-" + splitId - if (System.getProperty("spark.fake", "false") != "true") { - logInfo("Writing as block " + blockId ) - ssc.env.blockManager.put(blockId.toString, data.toIterator, StorageLevel.DISK_AND_MEMORY) - } else { - logInfo("Fake block") - } - blockId - } catch { - case e: Exception => - logError("Exception writing to block manager at interval " + interval + ": " + e.getMessage, e) - null - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala b/streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala new file mode 100644 index 0000000000..403ae233a5 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala @@ -0,0 +1,72 @@ +package spark.streaming + +import scala.collection.mutable.ArrayBuffer +import spark.streaming.SparkStreamContext._ + +class PairRDSFunctions[K: ClassManifest, V: ClassManifest](rds: RDS[(K,V)]) +extends Serializable { + + def ssc = rds.ssc + + /* ---------------------------------- */ + /* RDS operations for key-value pairs */ + /* ---------------------------------- */ + + def groupByKey(numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { + def createCombiner(v: V) = ArrayBuffer[V](v) + def mergeValue(c: ArrayBuffer[V], v: V) = (c += v) + def mergeCombiner(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = (c1 ++ c2) + combineByKey[ArrayBuffer[V]](createCombiner, mergeValue, mergeCombiner, numPartitions) + } + + def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int = 0): ShuffledRDS[K, V, V] = { + val cleanedReduceFunc = ssc.sc.clean(reduceFunc) + combineByKey[V]((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, numPartitions) + } + + private def combineByKey[C: ClassManifest]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + numPartitions: Int) : ShuffledRDS[K, V, C] = { + new ShuffledRDS[K, V, C](rds, createCombiner, mergeValue, mergeCombiner, numPartitions) + } + + def groupByKeyAndWindow( + windowTime: Time, + slideTime: Time, + numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { + rds.window(windowTime, slideTime).groupByKey(numPartitions) + } + + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int = 0): ShuffledRDS[K, V, V] = { + rds.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), numPartitions) + } + + // This method is the efficient sliding window reduce operation, + // which requires the specification of an inverse reduce function, + // so that new elements introduced in the window can be "added" using + // reduceFunc to the previous window's result and old elements can be + // "subtracted using invReduceFunc. 
+ def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int): ReducedWindowedRDS[K, V] = { + + new ReducedWindowedRDS[K, V]( + rds, + ssc.sc.clean(reduceFunc), + ssc.sc.clean(invReduceFunc), + windowTime, + slideTime, + numPartitions) + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/QueueInputRDS.scala b/streaming/src/main/scala/spark/streaming/QueueInputRDS.scala new file mode 100644 index 0000000000..31e6a64e21 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/QueueInputRDS.scala @@ -0,0 +1,36 @@ +package spark.streaming + +import spark.RDD +import spark.UnionRDD + +import scala.collection.mutable.Queue +import scala.collection.mutable.ArrayBuffer + +class QueueInputRDS[T: ClassManifest]( + ssc: SparkStreamContext, + val queue: Queue[RDD[T]], + oneAtATime: Boolean, + defaultRDD: RDD[T] + ) extends InputRDS[T](ssc) { + + override def start() { } + + override def stop() { } + + override def compute(validTime: Time): Option[RDD[T]] = { + val buffer = new ArrayBuffer[RDD[T]]() + if (oneAtATime && queue.size > 0) { + buffer += queue.dequeue() + } else { + buffer ++= queue + } + if (buffer.size > 0) { + Some(new UnionRDD(ssc.sc, buffer.toSeq)) + } else if (defaultRDD != null) { + Some(defaultRDD) + } else { + None + } + } + +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/RDS.scala b/streaming/src/main/scala/spark/streaming/RDS.scala index c8dd1015ed..fd923929e7 100644 --- a/streaming/src/main/scala/spark/streaming/RDS.scala +++ b/streaming/src/main/scala/spark/streaming/RDS.scala @@ -13,16 +13,18 @@ import spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import java.net.InetSocketAddress +import java.util.concurrent.ArrayBlockingQueue abstract class RDS[T: ClassManifest] (@transient val ssc: SparkStreamContext) extends Logging with Serializable { initLogging() - /* ---------------------------------------------- */ - /* Methods that must be implemented by subclasses */ - /* ---------------------------------------------- */ + /** + * ---------------------------------------------- + * Methods that must be implemented by subclasses + * ---------------------------------------------- + */ // Time by which the window slides in this RDS def slideTime: Time @@ -33,9 +35,11 @@ extends Logging with Serializable { // Key method that computes RDD for a valid time def compute (validTime: Time): Option[RDD[T]] - /* --------------------------------------- */ - /* Other general fields and methods of RDS */ - /* --------------------------------------- */ + /** + * --------------------------------------- + * Other general fields and methods of RDS + * --------------------------------------- + */ // Variable to store the RDDs generated earlier in time @transient private val generatedRDDs = new HashMap[Time, RDD[T]] () @@ -66,9 +70,9 @@ extends Logging with Serializable { this } + // Set caching level for the RDDs created by this RDS def persist(newLevel: StorageLevel): RDS[T] = persist(newLevel, StorageLevel.NONE, null) - // Turn on the default caching level for this RDD def persist(): RDS[T] = persist(StorageLevel.MEMORY_ONLY_DESER) // Turn on the default caching level for this RDD @@ -76,18 +80,20 @@ extends Logging with Serializable { def isInitialized = (zeroTime != null) - // This method initializes the RDS by setting the "zero" time, based on which - // the validity of future times is calculated. 
This method also recursively initializes - // its parent RDSs. - def initialize(firstInterval: Interval) { + /** + * This method initializes the RDS by setting the "zero" time, based on which + * the validity of future times is calculated. This method also recursively initializes + * its parent RDSs. + */ + def initialize(time: Time) { if (zeroTime == null) { - zeroTime = firstInterval.beginTime + zeroTime = time } logInfo(this + " initialized") - dependencies.foreach(_.initialize(firstInterval)) + dependencies.foreach(_.initialize(zeroTime)) } - // This method checks whether the 'time' is valid wrt slideTime for generating RDD + /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ private def isTimeValid (time: Time): Boolean = { if (!isInitialized) throw new Exception (this.toString + " has not been initialized") @@ -98,11 +104,13 @@ extends Logging with Serializable { } } - // This method either retrieves a precomputed RDD of this RDS, - // or computes the RDD (if the time is valid) + /** + * This method either retrieves a precomputed RDD of this RDS, + * or computes the RDD (if the time is valid) + */ def getOrCompute(time: Time): Option[RDD[T]] = { - - // if RDD was already generated, then retrieve it from HashMap + // If this RDS was not initialized (i.e., zeroTime not set), then do it + // If RDD was already generated, then retrieve it from HashMap generatedRDDs.get(time) match { // If an RDD was already generated and is being reused, then @@ -115,15 +123,12 @@ extends Logging with Serializable { if (isTimeValid(time)) { compute(time) match { case Some(newRDD) => - if (System.getProperty("spark.fake", "false") != "true" || - newRDD.getStorageLevel == StorageLevel.NONE) { - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { - newRDD.persist(checkpointLevel) - logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) - } else if (storageLevel != StorageLevel.NONE) { - newRDD.persist(storageLevel) - logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) - } + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.persist(checkpointLevel) + logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) + } else if (storageLevel != StorageLevel.NONE) { + newRDD.persist(storageLevel) + logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) } generatedRDDs.put(time.copy(), newRDD) Some(newRDD) @@ -136,8 +141,10 @@ extends Logging with Serializable { } } - // This method generates a SparkStream job for the given time - // and may require to be overriden by subclasses + /** + * This method generates a SparkStream job for the given time + * and may require to be overriden by subclasses + */ def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { case Some(rdd) => { @@ -151,9 +158,11 @@ extends Logging with Serializable { } } - /* -------------- */ - /* RDS operations */ - /* -------------- */ + /** + * -------------- + * RDS operations + * -------------- + */ def map[U: ClassManifest](mapFunc: T => U) = new MappedRDS(this, ssc.sc.clean(mapFunc)) @@ -185,6 +194,15 @@ extends Logging with Serializable { newrds } + private[streaming] def toQueue() = { + val queue = new ArrayBlockingQueue[RDD[T]](10000) + this.foreachRDD(rdd => { + println("Added RDD " + rdd.id) + queue.add(rdd) + }) + queue + } + def print() = { def foreachFunc = (rdd: RDD[T], time: Time) => { val first11 = rdd.take(11) 
@@ -229,198 +247,23 @@ extends Logging with Serializable { } -class PairRDSFunctions[K: ClassManifest, V: ClassManifest](rds: RDS[(K,V)]) -extends Serializable { - - def ssc = rds.ssc - - /* ---------------------------------- */ - /* RDS operations for key-value pairs */ - /* ---------------------------------- */ - - def groupByKey(numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { - def createCombiner(v: V) = ArrayBuffer[V](v) - def mergeValue(c: ArrayBuffer[V], v: V) = (c += v) - def mergeCombiner(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = (c1 ++ c2) - combineByKey[ArrayBuffer[V]](createCombiner, mergeValue, mergeCombiner, numPartitions) - } - - def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int = 0): ShuffledRDS[K, V, V] = { - val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - combineByKey[V]((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, numPartitions) - } - - private def combineByKey[C: ClassManifest]( - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiner: (C, C) => C, - numPartitions: Int) : ShuffledRDS[K, V, C] = { - new ShuffledRDS[K, V, C](rds, createCombiner, mergeValue, mergeCombiner, numPartitions) - } - - def groupByKeyAndWindow( - windowTime: Time, - slideTime: Time, - numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { - rds.window(windowTime, slideTime).groupByKey(numPartitions) - } - - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, - numPartitions: Int = 0): ShuffledRDS[K, V, V] = { - rds.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), numPartitions) - } - - // This method is the efficient sliding window reduce operation, - // which requires the specification of an inverse reduce function, - // so that new elements introduced in the window can be "added" using - // reduceFunc to the previous window's result and old elements can be - // "subtracted using invReduceFunc. - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, - numPartitions: Int): ReducedWindowedRDS[K, V] = { - - new ReducedWindowedRDS[K, V]( - rds, - ssc.sc.clean(reduceFunc), - ssc.sc.clean(invReduceFunc), - windowTime, - slideTime, - numPartitions) - } -} - - abstract class InputRDS[T: ClassManifest] ( - val inputName: String, - val batchDuration: Time, ssc: SparkStreamContext) extends RDS[T](ssc) { override def dependencies = List() - override def slideTime = batchDuration + override def slideTime = ssc.batchDuration - def setReference(time: Time, reference: AnyRef) -} - - -class FileInputRDS( - val fileInputName: String, - val directory: String, - ssc: SparkStreamContext) -extends InputRDS[String](fileInputName, LongTime(1000), ssc) { - - @transient val generatedFiles = new HashMap[Time,String] - - // TODO(Haoyuan): This is for the performance test. - @transient - val rdd = ssc.sc.textFile(SparkContext.inputFile, - SparkContext.idealPartitions).asInstanceOf[RDD[String]] + def start() - override def compute(validTime: Time): Option[RDD[String]] = { - generatedFiles.get(validTime) match { - case Some(file) => - logInfo("Reading from file " + file + " for time " + validTime) - // Some(ssc.sc.textFile(file).asInstanceOf[RDD[String]]) - // The following line is for HDFS performance test. Sould comment out the above line. 
- Some(rdd) - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - generatedFiles += ((time, reference.toString)) - logInfo("Reference added for time " + time + " - " + reference.toString) - } -} - -class NetworkInputRDS[T: ClassManifest]( - val networkInputName: String, - val addresses: Array[InetSocketAddress], - batchDuration: Time, - ssc: SparkStreamContext) -extends InputRDS[T](networkInputName, batchDuration, ssc) { - - - // TODO(Haoyuan): This is for the performance test. - @transient var rdd: RDD[T] = null - - if (System.getProperty("spark.fake", "false") == "true") { - logInfo("Running initial count to cache fake RDD") - rdd = ssc.sc.textFile(SparkContext.inputFile, - SparkContext.idealPartitions).asInstanceOf[RDD[T]] - val fakeCacheLevel = System.getProperty("spark.fake.cache", "") - if (fakeCacheLevel == "MEMORY_ONLY_2") { - rdd.persist(StorageLevel.MEMORY_ONLY_2) - } else if (fakeCacheLevel == "MEMORY_ONLY_DESER_2") { - rdd.persist(StorageLevel.MEMORY_ONLY_2) - } else if (fakeCacheLevel != "") { - logError("Invalid fake cache level: " + fakeCacheLevel) - System.exit(1) - } - rdd.count() - } - - @transient val references = new HashMap[Time,String] - - override def compute(validTime: Time): Option[RDD[T]] = { - if (System.getProperty("spark.fake", "false") == "true") { - logInfo("Returning fake RDD at " + validTime) - return Some(rdd) - } - references.get(validTime) match { - case Some(reference) => - if (reference.startsWith("file") || reference.startsWith("hdfs")) { - logInfo("Reading from file " + reference + " for time " + validTime) - Some(ssc.sc.textFile(reference).asInstanceOf[RDD[T]]) - } else { - logInfo("Getting from BlockManager " + reference + " for time " + validTime) - Some(new BlockRDD(ssc.sc, Array(reference))) - } - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - references += ((time, reference.toString)) - logInfo("Reference added for time " + time + " - " + reference.toString) - } + def stop() } -class TestInputRDS( - val testInputName: String, - batchDuration: Time, - ssc: SparkStreamContext) -extends InputRDS[String](testInputName, batchDuration, ssc) { - - @transient val references = new HashMap[Time,Array[String]] - - override def compute(validTime: Time): Option[RDD[String]] = { - references.get(validTime) match { - case Some(reference) => - Some(new BlockRDD[String](ssc.sc, reference)) - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - references += ((time, reference.asInstanceOf[Array[String]])) - } -} - +/** + * TODO + */ class MappedRDS[T: ClassManifest, U: ClassManifest] ( parent: RDS[T], @@ -437,6 +280,10 @@ extends RDS[U](parent.ssc) { } +/** + * TODO + */ + class FlatMappedRDS[T: ClassManifest, U: ClassManifest]( parent: RDS[T], flatMapFunc: T => Traversable[U]) @@ -452,6 +299,10 @@ extends RDS[U](parent.ssc) { } +/** + * TODO + */ + class FilteredRDS[T: ClassManifest](parent: RDS[T], filterFunc: T => Boolean) extends RDS[T](parent.ssc) { @@ -464,6 +315,11 @@ extends RDS[T](parent.ssc) { } } + +/** + * TODO + */ + class MapPartitionedRDS[T: ClassManifest, U: ClassManifest]( parent: RDS[T], mapPartFunc: Iterator[T] => Iterator[U]) @@ -478,6 +334,11 @@ extends RDS[U](parent.ssc) { } 
} + +/** + * TODO + */ + class GlommedRDS[T: ClassManifest](parent: RDS[T]) extends RDS[Array[T]](parent.ssc) { override def dependencies = List(parent) @@ -490,6 +351,10 @@ class GlommedRDS[T: ClassManifest](parent: RDS[T]) extends RDS[Array[T]](parent. } +/** + * TODO + */ + class ShuffledRDS[K: ClassManifest, V: ClassManifest, C: ClassManifest]( parent: RDS[(K,V)], createCombiner: V => C, @@ -519,6 +384,10 @@ class ShuffledRDS[K: ClassManifest, V: ClassManifest, C: ClassManifest]( } +/** + * TODO + */ + class UnifiedRDS[T: ClassManifest](parents: Array[RDS[T]]) extends RDS[T](parents(0).ssc) { @@ -553,6 +422,10 @@ extends RDS[T](parents(0).ssc) { } +/** + * TODO + */ + class PerElementForEachRDS[T: ClassManifest] ( parent: RDS[T], foreachFunc: T => Unit) @@ -580,6 +453,10 @@ extends RDS[Unit](parent.ssc) { } +/** + * TODO + */ + class PerRDDForEachRDS[T: ClassManifest] ( parent: RDS[T], foreachFunc: (RDD[T], Time) => Unit) diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 8df346559c..83f874e550 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -1,16 +1,11 @@ package spark.streaming +import spark.streaming.util.RecurringTimer import spark.SparkEnv import spark.Logging import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.collection.mutable.ArrayBuffer -import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ -import akka.util.duration._ sealed trait SchedulerMessage case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage @@ -20,162 +15,42 @@ class Scheduler( ssc: SparkStreamContext, inputRDSs: Array[InputRDS[_]], outputRDSs: Array[RDS[_]]) -extends Actor with Logging { - - class InputState (inputNames: Array[String]) { - val inputsLeft = new HashSet[String]() - inputsLeft ++= inputNames - - val startTime = System.currentTimeMillis - - def delay() = System.currentTimeMillis - startTime - - def addGeneratedInput(inputName: String) = inputsLeft -= inputName - - def areAllInputsGenerated() = (inputsLeft.size == 0) - - override def toString(): String = { - val left = if (inputsLeft.size == 0) "" else inputsLeft.reduceLeft(_ + ", " + _) - return "Inputs left = [ " + left + " ]" - } - } - +extends Logging { initLogging() - val inputNames = inputRDSs.map(_.inputName).toArray - val inputStates = new HashMap[Interval, InputState]() - val currentJobs = System.getProperty("spark.stream.currentJobs", "1").toInt - val jobManager = new JobManager(ssc, currentJobs) - - // TODO(Haoyuan): The following line is for performance test only. 
- var cnt: Int = System.getProperty("spark.stream.fake.cnt", "60").toInt - var lastInterval: Interval = null - - - /*remote.register("SparkStreamScheduler", actorOf[Scheduler])*/ - logInfo("Registered actor on port ") - - /*jobManager.start()*/ - startStreamReceivers() - - def receive = { - case InputGenerated(inputName, interval, reference) => { - addGeneratedInput(inputName, interval, reference) - } - case Test() => logInfo("TEST PASSED") - } - - def addGeneratedInput(inputName: String, interval: Interval, reference: AnyRef = null) { - logInfo("Input " + inputName + " generated for interval " + interval) - inputStates.get(interval) match { - case None => inputStates.put(interval, new InputState(inputNames)) - case _ => - } - inputStates(interval).addGeneratedInput(inputName) + val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt + val jobManager = new JobManager(ssc, concurrentJobs) + val timer = new RecurringTimer(ssc.batchDuration, generateRDDs(_)) - inputRDSs.filter(_.inputName == inputName).foreach(inputRDS => { - inputRDS.setReference(interval.endTime, reference) - if (inputRDS.isInstanceOf[TestInputRDS]) { - TestInputBlockTracker.addBlocks(interval.endTime, reference) - } - } - ) - - def getNextInterval(): Option[Interval] = { - logDebug("Last interval is " + lastInterval) - val readyIntervals = inputStates.filter(_._2.areAllInputsGenerated).keys - /*inputState.foreach(println) */ - logDebug("InputState has " + inputStates.size + " intervals, " + readyIntervals.size + " ready intervals") - return readyIntervals.find(lastInterval == null || _.beginTime == lastInterval.endTime) - } - - var nextInterval = getNextInterval() - var count = 0 - while(nextInterval.isDefined) { - val inputState = inputStates.get(nextInterval.get).get - generateRDDsForInterval(nextInterval.get) - logInfo("Skew delay for " + nextInterval.get.endTime + " is " + (inputState.delay / 1000.0) + " s") - inputStates.remove(nextInterval.get) - lastInterval = nextInterval.get - nextInterval = getNextInterval() - count += 1 - /*if (nextInterval.size == 0 && inputState.size > 0) { - logDebug("Next interval not ready, pending intervals " + inputState.size) - }*/ - } - logDebug("RDDs generated for " + count + " intervals") - - /* - if (inputState(interval).areAllInputsGenerated) { - generateRDDsForInterval(interval) - lastInterval = interval - inputState.remove(interval) - } else { - logInfo("All inputs not generated for interval " + interval) - } - */ + def start() { + + val zeroTime = Time(timer.start()) + outputRDSs.foreach(_.initialize(zeroTime)) + inputRDSs.par.foreach(_.start()) + logInfo("Scheduler started") } - - def generateRDDsForInterval (interval: Interval) { - logInfo("Generating RDDs for interval " + interval) + + def stop() { + timer.stop() + inputRDSs.par.foreach(_.stop()) + logInfo("Scheduler stopped") + } + + def generateRDDs (time: Time) { + logInfo("Generating RDDs for time " + time) outputRDSs.foreach(outputRDS => { - if (!outputRDS.isInitialized) outputRDS.initialize(interval) - outputRDS.generateJob(interval.endTime) match { + outputRDS.generateJob(time) match { case Some(job) => submitJob(job) case None => } } ) - // TODO(Haoyuan): This comment is for performance test only. - if (System.getProperty("spark.fake", "false") == "true" || System.getProperty("spark.stream.fake", "false") == "true") { - cnt -= 1 - if (cnt <= 0) { - logInfo("My time is up! 
" + cnt) - System.exit(1) - } - } + logInfo("Generated RDDs for time " + time) } - def submitJob(job: Job) { - logInfo("Submitting " + job + " to JobManager") - /*jobManager ! RunJob(job)*/ + def submitJob(job: Job) { jobManager.runJob(job) } - - def startStreamReceivers() { - val testStreamReceiverNames = new ArrayBuffer[(String, Long)]() - inputRDSs.foreach (inputRDS => { - inputRDS match { - case fileInputRDS: FileInputRDS => { - val fileStreamReceiver = new FileStreamReceiver( - fileInputRDS.inputName, - fileInputRDS.directory, - fileInputRDS.batchDuration.asInstanceOf[LongTime].milliseconds) - fileStreamReceiver.start() - } - case networkInputRDS: NetworkInputRDS[_] => { - val networkStreamReceiver = new NetworkStreamReceiver( - networkInputRDS.inputName, - networkInputRDS.batchDuration, - 0, - ssc, - if (ssc.tempDir == null) null else ssc.tempDir.toString) - networkStreamReceiver.start() - } - case testInputRDS: TestInputRDS => { - testStreamReceiverNames += - ((testInputRDS.inputName, testInputRDS.batchDuration.asInstanceOf[LongTime].milliseconds)) - } - } - }) - if (testStreamReceiverNames.size > 0) { - /*val testStreamCoordinator = new TestStreamCoordinator(testStreamReceiverNames.toArray)*/ - /*testStreamCoordinator.start()*/ - val actor = ssc.actorSystem.actorOf( - Props(new TestStreamCoordinator(testStreamReceiverNames.toArray)), - name = "TestStreamCoordinator") - } - } } diff --git a/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala b/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala index 51f8193740..d32f6d588c 100644 --- a/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala +++ b/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala @@ -1,22 +1,23 @@ package spark.streaming -import spark.SparkContext -import spark.SparkEnv -import spark.Utils +import spark.RDD import spark.Logging +import spark.SparkEnv +import spark.SparkContext import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.Queue -import java.net.InetSocketAddress import java.io.IOException -import java.util.UUID +import java.net.InetSocketAddress +import java.util.concurrent.atomic.AtomicInteger import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration - -import akka.actor._ -import akka.actor.Actor -import akka.util.duration._ +import org.apache.hadoop.io.LongWritable +import org.apache.hadoop.io.Text +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat class SparkStreamContext ( master: String, @@ -24,30 +25,37 @@ class SparkStreamContext ( val sparkHome: String = null, val jars: Seq[String] = Nil) extends Logging { - + initLogging() val sc = new SparkContext(master, frameworkName, sparkHome, jars) val env = SparkEnv.get - val actorSystem = env.actorSystem - - @transient val inputRDSs = new ArrayBuffer[InputRDS[_]]() - @transient val outputRDSs = new ArrayBuffer[RDS[_]]() - var tempDirRoot: String = null - var tempDir: Path = null - - def readNetworkStream[T: ClassManifest]( + val inputRDSs = new ArrayBuffer[InputRDS[_]]() + val outputRDSs = new ArrayBuffer[RDS[_]]() + var batchDuration: Time = null + var scheduler: Scheduler = null + + def setBatchDuration(duration: Long) { + setBatchDuration(Time(duration)) + } + + def setBatchDuration(duration: Time) { + batchDuration = duration + } + + /* + def createNetworkStream[T: ClassManifest]( name: String, addresses: Array[InetSocketAddress], batchDuration: Time): RDS[T] = { - val inputRDS = 
new NetworkInputRDS[T](name, addresses, batchDuration, this) + val inputRDS = new NetworkInputRDS[T](this, addresses) inputRDSs += inputRDS inputRDS - } + } - def readNetworkStream[T: ClassManifest]( + def createNetworkStream[T: ClassManifest]( name: String, addresses: Array[String], batchDuration: Long): RDS[T] = { @@ -65,40 +73,100 @@ class SparkStreamContext ( addresses.map(stringToInetSocketAddress).toArray, LongTime(batchDuration)) } - - def readFileStream(name: String, directory: String): RDS[String] = { - val path = new Path(directory) - val fs = path.getFileSystem(new Configuration()) - val qualPath = path.makeQualified(fs) - val inputRDS = new FileInputRDS(name, qualPath.toString, this) + */ + + /** + * This function creates a input stream that monitors a Hadoop-compatible + * for new files and executes the necessary processing on them. + */ + def createFileStream[ + K: ClassManifest, + V: ClassManifest, + F <: NewInputFormat[K, V]: ClassManifest + ](directory: String): RDS[(K, V)] = { + val inputRDS = new FileInputRDS[K, V, F](this, new Path(directory)) inputRDSs += inputRDS inputRDS } - def readTestStream(name: String, batchDuration: Long): RDS[String] = { - val inputRDS = new TestInputRDS(name, LongTime(batchDuration), this) - inputRDSs += inputRDS + def createTextFileStream(directory: String): RDS[String] = { + createFileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) + } + + /** + * This function create a input stream from an queue of RDDs. In each batch, + * it will process either one or all of the RDDs returned by the queue + */ + def createQueueStream[T: ClassManifest]( + queue: Queue[RDD[T]], + oneAtATime: Boolean = true, + defaultRDD: RDD[T] = null + ): RDS[T] = { + val inputRDS = new QueueInputRDS(this, queue, oneAtATime, defaultRDD) + inputRDSs += inputRDS inputRDS } + + def createQueueStream[T: ClassManifest](iterator: Iterator[RDD[T]]): RDS[T] = { + val queue = new Queue[RDD[T]] + val inputRDS = createQueueStream(queue, true, null) + queue ++= iterator + inputRDS + } + + /** + * This function registers a RDS as an output stream that will be + * computed every interval. + */ def registerOutputStream (outputRDS: RDS[_]) { outputRDSs += outputRDS } - - def setTempDir(dir: String) { - tempDirRoot = dir + + /** + * This function verify whether the stream computation is eligible to be executed. + */ + def verify() { + if (batchDuration == null) { + throw new Exception("Batch duration has not been set") + } + if (batchDuration < Milliseconds(100)) { + logWarning("Batch duration of " + batchDuration + " is very low") + } + if (inputRDSs.size == 0) { + throw new Exception("No input RDSes created, so nothing to take input from") + } + if (outputRDSs.size == 0) { + throw new Exception("No output RDSes registered, so nothing to execute") + } + } - - def run () { - val ctxt = this - val actor = actorSystem.actorOf( - Props(new Scheduler(ctxt, inputRDSs.toArray, outputRDSs.toArray)), - name = "SparkStreamScheduler") - logInfo("Registered actor") - actorSystem.awaitTermination() + + /** + * This function starts the execution of the streams. + */ + def start() { + verify() + scheduler = new Scheduler(this, inputRDSs.toArray, outputRDSs.toArray) + scheduler.start() + } + + /** + * This function starts the execution of the streams. 
+ */ + def stop() { + try { + scheduler.stop() + sc.stop() + } catch { + case e: Exception => logWarning("Error while stopping", e) + } + + logInfo("SparkStreamContext stopped") } } + object SparkStreamContext { implicit def rdsToPairRdsFunctions [K: ClassManifest, V: ClassManifest] (rds: RDS[(K,V)]) = new PairRDSFunctions (rds) diff --git a/streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala b/streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala deleted file mode 100644 index 7e23b7bb82..0000000000 --- a/streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala +++ /dev/null @@ -1,42 +0,0 @@ -package spark.streaming -import spark.Logging -import scala.collection.mutable.{ArrayBuffer, HashMap} - -object TestInputBlockTracker extends Logging { - initLogging() - val allBlockIds = new HashMap[Time, ArrayBuffer[String]]() - - def addBlocks(intervalEndTime: Time, reference: AnyRef) { - allBlockIds.getOrElseUpdate(intervalEndTime, new ArrayBuffer[String]()) ++= reference.asInstanceOf[Array[String]] - } - - def setEndTime(intervalEndTime: Time) { - try { - val endTime = System.currentTimeMillis - allBlockIds.get(intervalEndTime) match { - case Some(blockIds) => { - val numBlocks = blockIds.size - var totalDelay = 0d - blockIds.foreach(blockId => { - val inputTime = getInputTime(blockId) - val delay = (endTime - inputTime) / 1000.0 - totalDelay += delay - logInfo("End-to-end delay for block " + blockId + " is " + delay + " s") - }) - logInfo("Average end-to-end delay for time " + intervalEndTime + " is " + (totalDelay / numBlocks) + " s") - allBlockIds -= intervalEndTime - } - case None => throw new Exception("Unexpected") - } - } catch { - case e: Exception => logError(e.toString) - } - } - - def getInputTime(blockId: String): Long = { - val parts = blockId.split("-") - /*logInfo(blockId + " -> " + parts(4)) */ - parts(4).toLong - } -} - diff --git a/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala b/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala index a7a5635aa5..bbf2c7bf5e 100644 --- a/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala +++ b/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala @@ -34,8 +34,8 @@ extends Thread with Logging { class DataHandler( inputName: String, - longIntervalDuration: LongTime, - shortIntervalDuration: LongTime, + longIntervalDuration: Time, + shortIntervalDuration: Time, blockManager: BlockManager ) extends Logging { @@ -61,8 +61,8 @@ extends Thread with Logging { initLogging() - val shortIntervalDurationMillis = shortIntervalDuration.asInstanceOf[LongTime].milliseconds - val longIntervalDurationMillis = longIntervalDuration.asInstanceOf[LongTime].milliseconds + val shortIntervalDurationMillis = shortIntervalDuration.toLong + val longIntervalDurationMillis = longIntervalDuration.toLong var currentBlock: Block = null var currentBucket: Bucket = null @@ -101,7 +101,7 @@ extends Thread with Logging { def updateCurrentBlock() { /*logInfo("Updating current block")*/ - val currentTime: LongTime = LongTime(System.currentTimeMillis) + val currentTime = Time(System.currentTimeMillis) val shortInterval = getShortInterval(currentTime) val longInterval = getLongInterval(shortInterval) @@ -318,12 +318,12 @@ extends Thread with Logging { val inputName = streamDetails.name val intervalDurationMillis = streamDetails.duration - val intervalDuration = LongTime(intervalDurationMillis) + val intervalDuration = Time(intervalDurationMillis) val dataHandler = new 
DataHandler( inputName, intervalDuration, - LongTime(TestStreamReceiver3.SHORT_INTERVAL_MILLIS), + Time(TestStreamReceiver3.SHORT_INTERVAL_MILLIS), blockManager) val connListener = new ConnectionListener(TestStreamReceiver3.PORT, dataHandler) @@ -382,7 +382,7 @@ extends Thread with Logging { def waitFor(time: Time) { val currentTimeMillis = System.currentTimeMillis - val targetTimeMillis = time.asInstanceOf[LongTime].milliseconds + val targetTimeMillis = time.milliseconds if (currentTimeMillis < targetTimeMillis) { val sleepTime = (targetTimeMillis - currentTimeMillis) Thread.sleep(sleepTime + 1) @@ -392,7 +392,7 @@ extends Thread with Logging { def notifyScheduler(interval: Interval, blockIds: Array[String]) { try { sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) - val time = interval.endTime.asInstanceOf[LongTime] + val time = interval.endTime val delay = (System.currentTimeMillis - time.milliseconds) / 1000.0 logInfo("Pushing delay for " + time + " is " + delay + " s") } catch { diff --git a/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala b/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala index 2c3f5d1b9d..a2babb23f4 100644 --- a/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala +++ b/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala @@ -24,8 +24,8 @@ extends Thread with Logging { class DataHandler( inputName: String, - longIntervalDuration: LongTime, - shortIntervalDuration: LongTime, + longIntervalDuration: Time, + shortIntervalDuration: Time, blockManager: BlockManager ) extends Logging { @@ -50,8 +50,8 @@ extends Thread with Logging { val syncOnLastShortInterval = true - val shortIntervalDurationMillis = shortIntervalDuration.asInstanceOf[LongTime].milliseconds - val longIntervalDurationMillis = longIntervalDuration.asInstanceOf[LongTime].milliseconds + val shortIntervalDurationMillis = shortIntervalDuration.milliseconds + val longIntervalDurationMillis = longIntervalDuration.milliseconds val buffer = ByteBuffer.allocateDirect(100 * 1024 * 1024) var currentShortInterval = Interval.currentInterval(shortIntervalDuration) @@ -145,7 +145,7 @@ extends Thread with Logging { if (syncOnLastShortInterval) { bucket += newBlock } - logDebug("Created " + newBlock + " with " + newBuffer.remaining + " bytes, creation delay is " + (System.currentTimeMillis - currentShortInterval.endTime.asInstanceOf[LongTime].milliseconds) / 1000.0 + " s" ) + logDebug("Created " + newBlock + " with " + newBuffer.remaining + " bytes, creation delay is " + (System.currentTimeMillis - currentShortInterval.endTime.milliseconds) / 1000.0 + " s" ) blockPushingExecutor.execute(new Runnable() { def run() { pushAndNotifyBlock(newBlock) } }) } } @@ -175,7 +175,7 @@ extends Thread with Logging { try{ if (blockManager != null) { val startTime = System.currentTimeMillis - logInfo(block + " put start delay is " + (startTime - block.shortInterval.endTime.asInstanceOf[LongTime].milliseconds) + " ms") + logInfo(block + " put start delay is " + (startTime - block.shortInterval.endTime.milliseconds) + " ms") /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY)*/ /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_2)*/ blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY_2) @@ -343,7 +343,7 @@ extends Thread with Logging { def waitFor(time: Time) { val currentTimeMillis = System.currentTimeMillis - val targetTimeMillis = 
time.asInstanceOf[LongTime].milliseconds + val targetTimeMillis = time.milliseconds if (currentTimeMillis < targetTimeMillis) { val sleepTime = (targetTimeMillis - currentTimeMillis) Thread.sleep(sleepTime + 1) @@ -353,7 +353,7 @@ extends Thread with Logging { def notifyScheduler(interval: Interval, blockIds: Array[String]) { try { sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) - val time = interval.endTime.asInstanceOf[LongTime] + val time = interval.endTime val delay = (System.currentTimeMillis - time.milliseconds) logInfo("Notification delay for " + time + " is " + delay + " ms") } catch { diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index b932fe9258..c4573137ae 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -1,19 +1,34 @@ package spark.streaming -abstract case class Time { +class Time(private var millis: Long) { - // basic operations that must be overridden - def copy(): Time - def zero: Time - def < (that: Time): Boolean - def += (that: Time): Time - def -= (that: Time): Time - def floor(that: Time): Time - def isMultipleOf(that: Time): Boolean + def copy() = new Time(this.millis) + + def zero = Time.zero + + def < (that: Time): Boolean = + (this.millis < that.millis) + + def <= (that: Time) = (this < that || this == that) + + def > (that: Time) = !(this <= that) + + def >= (that: Time) = !(this < that) + + def += (that: Time): Time = { + this.millis += that.millis + this + } + + def -= (that: Time): Time = { + this.millis -= that.millis + this + } - // derived operations composed of basic operations def + (that: Time) = this.copy() += that + def - (that: Time) = this.copy() -= that + def * (times: Int) = { var count = 0 var result = this.copy() @@ -23,63 +38,44 @@ abstract case class Time { } result } - def <= (that: Time) = (this < that || this == that) - def > (that: Time) = !(this <= that) - def >= (that: Time) = !(this < that) - def isZero = (this == zero) - def toFormattedString = toString -} - -object Time { - def Milliseconds(milliseconds: Long) = LongTime(milliseconds) - - def zero = LongTime(0) -} - -case class LongTime(var milliseconds: Long) extends Time { - - override def copy() = LongTime(this.milliseconds) - - override def zero = LongTime(0) - - override def < (that: Time): Boolean = - (this.milliseconds < that.asInstanceOf[LongTime].milliseconds) - - override def += (that: Time): Time = { - this.milliseconds += that.asInstanceOf[LongTime].milliseconds - this - } - override def -= (that: Time): Time = { - this.milliseconds -= that.asInstanceOf[LongTime].milliseconds - this + def floor(that: Time): Time = { + val t = that.millis + val m = math.floor(this.millis / t).toLong + new Time(m * t) } - override def floor(that: Time): Time = { - val t = that.asInstanceOf[LongTime].milliseconds - val m = this.milliseconds / t - LongTime(m.toLong * t) - } + def isMultipleOf(that: Time): Boolean = + (this.millis % that.millis == 0) - override def isMultipleOf(that: Time): Boolean = - (this.milliseconds % that.asInstanceOf[LongTime].milliseconds == 0) + def isZero = (this.millis == 0) - override def isZero = (this.milliseconds == 0) + override def toString() = (millis.toString + " ms") - override def toString = (milliseconds.toString + "ms") + def toFormattedString() = millis.toString + + def milliseconds() = millis +} - override def toFormattedString = milliseconds.toString +object Time { + val zero = 
new Time(0) + + def apply(milliseconds: Long) = new Time(milliseconds) + + implicit def toTime(long: Long) = Time(long) + + implicit def toLong(time: Time) = time.milliseconds } object Milliseconds { - def apply(milliseconds: Long) = LongTime(milliseconds) + def apply(milliseconds: Long) = Time(milliseconds) } object Seconds { - def apply(seconds: Long) = LongTime(seconds * 1000) + def apply(seconds: Long) = Time(seconds * 1000) } object Minutes { - def apply(minutes: Long) = LongTime(minutes * 60000) + def apply(minutes: Long) = Time(minutes * 60000) } diff --git a/streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala deleted file mode 100644 index 2ca72da79f..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala +++ /dev/null @@ -1,138 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.JavaConversions.mapAsScalaMap -import scala.collection.mutable.Queue - -import java.lang.{Long => JLong} - -object DumbTopKWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - - def add(v1: JLong, v2: JLong) = (v1 + v2) - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } - - - val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - wordCounts.persist(StorageLevel.MEMORY_ONLY) - val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) - - def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { - val taken = new Array[(String, JLong)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, JLong) = null - var swap: (String, JLong) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - /*println("count = " + count)*/ - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out 
of " + count + " items") - return taken.toIterator - } - - val k = 10 - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) - }) - - /* - windowedCounts.filter(_ == null).foreachRDD(rdd => { - val count = rdd.count - println("# of nulls = " + count) - })*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala deleted file mode 100644 index 34e7edfda9..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala +++ /dev/null @@ -1,92 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.collection.JavaConversions.mapAsScalaMap - -import java.lang.{Long => JLong} -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - -object DumbWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - def add(v1: JLong, v2: JLong) = (v1 + v2) - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - - map.toIterator - } - - val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - wordCounts.persist(StorageLevel.MEMORY_ONLY) - val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala b/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala new file mode 100644 index 0000000000..d56fdcdf29 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala @@ -0,0 +1,41 @@ +package spark.streaming.examples + +import spark.RDD +import spark.streaming.SparkStreamContext +import spark.streaming.SparkStreamContext._ +import spark.streaming.Seconds + +import scala.collection.mutable.SynchronizedQueue + +object ExampleOne { + + def 
main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: ExampleOne ") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new SparkStreamContext(args(0), "ExampleOne") + ssc.setBatchDuration(Seconds(1)) + + // Create the queue through which RDDs can be pushed to + // a QueueInputRDS + val rddQueue = new SynchronizedQueue[RDD[Int]]() + + // Create the QueueInputRDs and use it do some processing + val inputStream = ssc.createQueueStream(rddQueue) + val mappedStream = inputStream.map(x => (x % 10, 1)) + val reducedStream = mappedStream.reduceByKey(_ + _) + reducedStream.print() + ssc.start() + + // Create and push some RDDs into + for (i <- 1 to 30) { + rddQueue += ssc.sc.makeRDD(1 to 1000, 10) + Thread.sleep(1000) + } + ssc.stop() + System.exit(0) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala b/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala new file mode 100644 index 0000000000..4b8f6d609d --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala @@ -0,0 +1,47 @@ +package spark.streaming.examples + +import spark.streaming.SparkStreamContext +import spark.streaming.SparkStreamContext._ +import spark.streaming.Seconds +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + + +object ExampleTwo { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: ExampleOne ") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new SparkStreamContext(args(0), "ExampleTwo") + ssc.setBatchDuration(Seconds(2)) + + // Create the new directory + val directory = new Path(args(1)) + val fs = directory.getFileSystem(new Configuration()) + if (fs.exists(directory)) throw new Exception("This directory already exists") + fs.mkdirs(directory) + + // Create the FileInputRDS on the directory and use the + // stream to count words in new files created + val inputRDS = ssc.createTextFileStream(directory.toString) + val wordsRDS = inputRDS.flatMap(_.split(" ")) + val wordCountsRDS = wordsRDS.map(x => (x, 1)).reduceByKey(_ + _) + wordCountsRDS.print + ssc.start() + + // Creating new files in the directory + val text = "This is a text file" + for (i <- 1 to 30) { + ssc.sc.parallelize((1 to (i * 10)).map(_ => text), 10) + .saveAsTextFile(new Path(directory, i.toString).toString) + Thread.sleep(1000) + } + Thread.sleep(5000) // Waiting for the file to be processed + ssc.stop() + fs.delete(directory) + System.exit(0) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepCount.scala b/streaming/src/main/scala/spark/streaming/examples/GrepCount.scala deleted file mode 100644 index ec3e70f258..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/GrepCount.scala +++ /dev/null @@ -1,39 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object GrepCount { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: GrepCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "GrepCount") - - SparkContext.idealPartitions = idealPartitions - 
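Besides the queue-based overload that ExampleOne above exercises, the SparkStreamContext changes earlier in this patch also add an iterator-based createQueueStream. An illustrative sketch, not taken from the patch, assuming a context ssc built as in ExampleOne with spark.RDD and the implicits from SparkStreamContext._ in scope:

// Pre-build five batches as RDDs and hand them over as an iterator;
// the resulting stream consumes one queued RDD per batch interval.
val batches: Iterator[RDD[Int]] =
  (1 to 5).iterator.map(i => ssc.sc.makeRDD(1 to i * 100, 2))
val batchedStream = ssc.createQueueStream(batches)
batchedStream.map(x => (x % 10, 1)).reduceByKey(_ + _).print()
ssc.start()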
SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - //sentences.print - val matching = sentences.filter(_.contains("light")) - matching.foreachRDD(rdd => println(rdd.count)) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala b/streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala deleted file mode 100644 index 27ecced1c0..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala +++ /dev/null @@ -1,113 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkEnv -import spark.SparkContext -import spark.storage.StorageLevel -import spark.network.Message -import spark.network.ConnectionManagerId - -import java.nio.ByteBuffer - -object GrepCount2 { - - def startSparkEnvs(sc: SparkContext) { - - val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) - sc.runJob(dummy, (_: Iterator[Int]) => {}) - - println("SparkEnvs started") - Thread.sleep(1000) - /*sc.runJob(sc.parallelize(0 to 1000, 100), (_: Iterator[Int]) => {})*/ - } - - def warmConnectionManagers(sc: SparkContext) { - val slaveConnManagerIds = sc.parallelize(0 to 100, 100).map( - i => SparkEnv.get.connectionManager.id).collect().distinct - println("\nSlave ConnectionManagerIds") - slaveConnManagerIds.foreach(println) - println - - Thread.sleep(1000) - val numSlaves = slaveConnManagerIds.size - val count = 3 - val size = 5 * 1024 * 1024 - val iterations = (500 * 1024 * 1024 / (numSlaves * size)).toInt - println("count = " + count + ", size = " + size + ", iterations = " + iterations) - - (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until numSlaves, numSlaves).map(i => { - val connManager = SparkEnv.get.connectionManager - val thisConnManagerId = connManager.id - /*connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - })*/ - - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - val futures = (0 until iterations).map(i => { - slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - println("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") - connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) - }) - }).flatMap(x => x) - val results = futures.map(f => f()) - val finishTime = System.currentTimeMillis - - - val mb = size * results.size / 1024.0 / 1024.0 - val ms = finishTime - startTime - - val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" - println(resultStr) - System.gc() - resultStr - }).collect() - - println("---------------------") - println("Run " + i) - resultStrs.foreach(println) - println("---------------------") - }) - } - - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: GrepCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "GrepCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - /*startSparkEnvs(ssc.sc)*/ - warmConnectionManagers(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-"+i, 500)).toArray - ) - - 
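The Time class rewritten earlier in this patch stores a plain millisecond count, so batch durations combine with ordinary arithmetic while floor and isMultipleOf align timestamps to interval boundaries. A small worked sketch with hypothetical values, assuming the spark.streaming package is imported:

val batchDuration = Seconds(2)                    // Time(2000)
val now = Time(7345)
val intervalStart = now.floor(batchDuration)      // Time(6000)
val nextBoundary = intervalStart + batchDuration  // Time(8000)
assert(nextBoundary.isMultipleOf(batchDuration))  // 8000 % 2000 == 0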
val matching = sentences.filter(_.contains("light")) - matching.foreachRDD(rdd => println(rdd.count)) - - ssc.run - } -} - - - - diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala b/streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala deleted file mode 100644 index f9674136fe..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala +++ /dev/null @@ -1,54 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object GrepCountApprox { - var inputFile : String = null - var hdfs : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 5) { - println ("Usage: GrepCountApprox ") - System.exit(1) - } - - hdfs = args(1) - inputFile = hdfs + args(2) - idealPartitions = args(3).toInt - val timeout = args(4).toLong - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "GrepCount") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - ssc.setTempDir(hdfs + "/tmp") - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - //sentences.print - val matching = sentences.filter(_.contains("light")) - var i = 0 - val startTime = System.currentTimeMillis - matching.foreachRDD { rdd => - val myNum = i - val result = rdd.countApprox(timeout) - val initialTime = (System.currentTimeMillis - startTime) / 1000.0 - printf("APPROX\t%.2f\t%d\tinitial\t%.1f\t%.1f\n", initialTime, myNum, result.initialValue.mean, - result.initialValue.high - result.initialValue.low) - result.onComplete { r => - val finalTime = (System.currentTimeMillis - startTime) / 1000.0 - printf("APPROX\t%.2f\t%d\tfinal\t%.1f\t0.0\t%.1f\n", finalTime, myNum, r.mean, finalTime - initialTime) - } - i += 1 - } - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala deleted file mode 100644 index a75ccd3a56..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala +++ /dev/null @@ -1,30 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -object SimpleWordCount { - - def main (args: Array[String]) { - - if (args.length < 1) { - println ("Usage: SparkStreamContext []") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount") - if (args.length > 1) { - ssc.setTempDir(args(1)) - } - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) - /*sentences.print*/ - - val words = sentences.flatMap(_.split(" ")) - - val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1) - counts.print - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala deleted file mode 100644 index 9672e64b13..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala +++ /dev/null @@ -1,51 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import scala.util.Sorting - -object SimpleWordCount2 { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - 
.collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SimpleWordCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - - val words = sentences.flatMap(_.split(" ")) - - val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10) - counts.foreachRDD(_.collect()) - /*words.foreachRDD(_.countByValue())*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala deleted file mode 100644 index 503033a8e5..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala +++ /dev/null @@ -1,83 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import scala.collection.JavaConversions.mapAsScalaMap -import scala.util.Sorting -import java.lang.{Long => JLong} - -object SimpleWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SimpleWordCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 400)).toArray - ) - - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } - - - /*val words = sentences.flatMap(_.split(" "))*/ - /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10)*/ - val counts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - counts.foreachRDD(_.collect()) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala b/streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala deleted file mode 100644 index 031e989c87..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala +++ /dev/null @@ -1,97 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting - -object TopContentCount { - - case class Event(val country: String, val content: String) - - object Event { - def create(string: String): Event = { - val parts = string.split(":") - new Event(parts(0), 
parts(1)) - } - } - - def main(args: Array[String]) { - - if (args.length < 2) { - println ("Usage: GrepCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "TopContentCount") - val sc = ssc.sc - val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) - sc.runJob(dummy, (_: Iterator[Int]) => {}) - - - val numEventStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - val eventStrings = new UnifiedRDS( - (1 to numEventStreams).map(i => ssc.readTestStream("Events-" + i, 1000)).toArray - ) - - def parse(string: String) = { - val parts = string.split(":") - (parts(0), parts(1)) - } - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val events = eventStrings.map(x => parse(x)) - /*events.print*/ - - val parallelism = 8 - val counts_per_content_per_country = events - .map(x => (x, 1)) - .reduceByKey(_ + _) - /*.reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), parallelism)*/ - /*counts_per_content_per_country.print*/ - - /* - counts_per_content_per_country.persist( - StorageLevel.MEMORY_ONLY_DESER, - StorageLevel.MEMORY_ONLY_DESER_2, - Seconds(1) - )*/ - - val counts_per_country = counts_per_content_per_country - .map(x => (x._1._1, (x._1._2, x._2))) - .groupByKey() - counts_per_country.print - - - def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { - implicit val countOrdering = new Ordering[(String, Int)] { - override def compare(count1: (String, Int), count2: (String, Int)): Int = { - count2._2 - count1._2 - } - } - val array = data.toArray - Sorting.quickSort(array) - val taken = array.take(k) - taken - } - - val k = 10 - val topKContents_per_country = counts_per_country - .map(x => (x._1, topK(x._2, k))) - .map(x => (x._1, x._2.map(_.toString).reduceLeft(_ + ", " + _))) - - topKContents_per_country.print - - ssc.run - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala deleted file mode 100644 index 679ed0a7ef..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala +++ /dev/null @@ -1,103 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting - -object TopKWordCount2 { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) 
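The reduceByKeyAndWindow(add _, subtract _, ...) calls in these deleted examples rely on subtract being the inverse of add: when the window slides, values arriving in the newest batch are folded in and values from the batch that dropped out of the window are removed, so the running total is never recomputed from scratch. A tiny worked sketch of that contract with hypothetical counts:

def add(v1: Int, v2: Int) = v1 + v2
def subtract(v1: Int, v2: Int) = v1 - v2

val oldWindowCount = 42  // count for one word over the previous window position
val entering = 7         // occurrences arriving in the newest batch
val leaving = 5          // occurrences in the batch that just left the window
val newWindowCount = subtract(add(oldWindowCount, entering), leaving)  // 44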
- - def topK(data: Iterator[(String, Int)], k: Int): Iterator[(String, Int)] = { - val taken = new Array[(String, Int)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, Int) = null - var swap: (String, Int) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - println("count = " + count) - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out of " + count + " items") - return taken.toIterator - } - - val k = 10 - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) - }) - - /* - windowedCounts.filter(_ == null).foreachRDD(rdd => { - val count = rdd.count - println("# of nulls = " + count) - })*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala deleted file mode 100644 index c873fbd0f0..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala +++ /dev/null @@ -1,142 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.JavaConversions.mapAsScalaMap -import scala.collection.mutable.Queue - -import java.lang.{Long => JLong} - -object TopKWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "TopKWordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - /*moreWarmup(ssc.sc)*/ - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray - ) - - /*val words = sentences.flatMap(_.split(" "))*/ - - /*def add(v1: Int, v2: Int) = (v1 + v2) */ - /*def subtract(v1: Int, v2: Int) = (v1 - v2) */ - - def add(v1: JLong, v2: JLong) = (v1 + v2) - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } - - - val windowedCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Milliseconds(500), 10) - 
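The topK helpers in these deleted examples work in two levels: mapPartitions(topK(_, k)) keeps only a partial top-k per partition, and the driver merges the collected partials with one more topK pass over far less data. The same idea in a compact plain-Scala sketch with hypothetical counts:

val k = 2
val partialTopKs: Seq[Seq[(String, Int)]] = Seq(
  Seq("spark" -> 9, "rdd" -> 4),      // top-k from partition 0
  Seq("batch" -> 8, "stream" -> 3))   // top-k from partition 1
val globalTopK = partialTopKs.flatten.sortBy(p => -p._2).take(k)
// globalTopK == Seq(("spark", 9), ("batch", 8))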
/*windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1))*/ - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) - - def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { - val taken = new Array[(String, JLong)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, JLong) = null - var swap: (String, JLong) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - println("count = " + count) - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out of " + count + " items") - return taken.toIterator - } - - val k = 50 - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) - }) - - /* - windowedCounts.filter(_ == null).foreachRDD(rdd => { - val count = rdd.count - println("# of nulls = " + count) - })*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount.scala deleted file mode 100644 index fb5508ffcc..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount.scala +++ /dev/null @@ -1,62 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object WordCount { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: WordCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "WordCountWindow") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) - //sentences.print - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), - // System.getProperty("spark.default.parallelism", "1").toInt) - //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) - //windowedCounts.print - - val parallelism = System.getProperty("spark.default.parallelism", "1").toInt - - //val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) - //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) - //val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(_ + _) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), - parallelism) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(10)) - - //windowedCounts.print - windowedCounts.register - 
//windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount1.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount1.scala deleted file mode 100644 index 42d985920a..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount1.scala +++ /dev/null @@ -1,46 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object WordCount1 { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: WordCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "WordCountWindow") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - //sentences.print - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala deleted file mode 100644 index 9168a2fe2f..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ /dev/null @@ -1,55 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting - -object WordCount2 { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - /*moreWarmup(ssc.sc)*/ - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 6) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala deleted file mode 100644 index 1920915af7..0000000000 --- 
a/streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala +++ /dev/null @@ -1,94 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.collection.JavaConversions.mapAsScalaMap - -import java.lang.{Long => JLong} -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - - -object WordCount2_ExtraFunctions { - - def add(v1: JLong, v2: JLong) = (v1 + v2) - - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } -} - -object WordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 40).foreach {i => - sc.parallelize(1 to 20000000, 1000) - .map(_ % 1331).map(_.toString) - .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - - GrepCount2.warmConnectionManagers(ssc.sc) - /*moreWarmup(ssc.sc)*/ - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray - ) - - val windowedCounts = sentences - .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions) - .reduceByKeyAndWindow(WordCount2_ExtraFunctions.add _, WordCount2_ExtraFunctions.subtract _, Seconds(10), Milliseconds(500), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount3.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount3.scala deleted file mode 100644 index 018c19a509..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount3.scala +++ /dev/null @@ -1,49 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -object WordCount3 { - - def main (args: Array[String]) { - - if (args.length < 1) { - println ("Usage: SparkStreamContext []") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount") - if (args.length > 1) { - ssc.setTempDir(args(1)) - } - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - /*sentences.print*/ - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - /*val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1)*/ - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), 1) - /*windowedCounts.print */ - 
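The splitAndCountPartitions helpers in these examples replace the usual flatMap/map/reduceByKey pipeline with a hand-rolled pass that splits and pre-aggregates each partition into a mutable hash map before anything is shuffled, avoiding a tuple allocation per word. A simplified sketch of the same per-partition logic in plain Scala:

def countWordsInPartition(lines: Iterator[String]): Iterator[(String, Long)] = {
  // Pre-aggregate locally; only (word, localCount) pairs leave the partition.
  val counts = scala.collection.mutable.HashMap.empty[String, Long]
  for (line <- lines; word <- line.split(" ") if word.nonEmpty)
    counts(word) = counts.getOrElse(word, 0L) + 1L
  counts.iterator
}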
- def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { - implicit val countOrdering = new Ordering[(String, Int)] { - override def compare(count1: (String, Int), count2: (String, Int)): Int = { - count2._2 - count1._2 - } - } - val array = data.toArray - Sorting.quickSort(array) - array.take(k) - } - - val k = 10 - val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) - topKWindowedCounts.print - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala deleted file mode 100644 index 82b9fa781d..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala +++ /dev/null @@ -1,41 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ -import spark.SparkContext - -object WordCountEc2 { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: SparkStreamContext ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val ssc = new SparkStreamContext(args(0), "Test") - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - /*sentences.foreach(println)*/ - - val words = sentences.flatMap(_.split(" ")) - /*words.foreach(println)*/ - - val counts = words.map(x => (x, 1)).reduceByKey(_ + _) - /*counts.foreach(println)*/ - - counts.foreachRDD(rdd => rdd.collect.foreach(x => x)) - /*counts.register*/ - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala deleted file mode 100644 index 114dd144f1..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala +++ /dev/null @@ -1,51 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -object WordCountTrivialWindow { - - def main (args: Array[String]) { - - if (args.length < 1) { - println ("Usage: SparkStreamContext []") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCountTrivialWindow") - if (args.length > 1) { - ssc.setTempDir(args(1)) - } - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - /*sentences.print*/ - - val words = sentences.flatMap(_.split(" ")) - - /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1)*/ - /*counts.print*/ - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1) - /*windowedCounts.print */ - - def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { - implicit val countOrdering = new Ordering[(String, Int)] { - override def compare(count1: (String, Int), count2: (String, Int)): Int = { - count2._2 - count1._2 - } - } - val array = data.toArray - Sorting.quickSort(array) - array.take(k) - } - - val k = 10 - val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) - topKWindowedCounts.print - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax.scala 
b/streaming/src/main/scala/spark/streaming/examples/WordMax.scala deleted file mode 100644 index fbfc48030f..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordMax.scala +++ /dev/null @@ -1,64 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object WordMax { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: WordCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "WordCountWindow") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) - //sentences.print - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - def max(v1: Int, v2: Int) = (if (v1 > v2) v1 else v2) - - //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), - // System.getProperty("spark.default.parallelism", "1").toInt) - //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) - //windowedCounts.print - - val parallelism = System.getProperty("spark.default.parallelism", "1").toInt - - val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) - //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER) - localCounts.persist(StorageLevel.MEMORY_ONLY_DESER_2) - val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(max _) - - //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), - // parallelism) - //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) - - //windowedCounts.print - windowedCounts.register - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala new file mode 100644 index 0000000000..6125bb82eb --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala @@ -0,0 +1,52 @@ +package spark.streaming.util + +class RecurringTimer(period: Long, callback: (Long) => Unit) { + + val minPollTime = 25L + + val pollTime = { + if (period / 10.0 > minPollTime) { + (period / 10.0).toLong + } else { + minPollTime + } + } + + val thread = new Thread() { + override def run() { loop } + } + + var nextTime = 0L + + def start(): Long = { + nextTime = (math.floor(System.currentTimeMillis() / period) + 1).toLong * period + thread.start() + nextTime + } + + def stop() { + thread.interrupt() + } + + def loop() { + try { + while (true) { + val beforeSleepTime = System.currentTimeMillis() + while (beforeSleepTime >= nextTime) { + callback(nextTime) + nextTime += period + } + val sleepTime = if (nextTime - beforeSleepTime < 2 * pollTime) { + nextTime - beforeSleepTime + } else { + pollTime + } + Thread.sleep(sleepTime) + val afterSleepTime = System.currentTimeMillis() + } + } catch { + case e: InterruptedException => + } 
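The RecurringTimer added here passes the callback the logical time of each tick, aligned to multiples of the period (start() rounds up to the next period boundary), which is what lets a scheduler derive clean batch boundaries from it. An illustrative use, not part of the patch, assuming spark.streaming.util.RecurringTimer is imported:

val timer = new RecurringTimer(1000, (time: Long) => println("tick at " + time + " ms"))
val firstTick = timer.start()  // the next multiple of 1000 ms after now
Thread.sleep(5000)             // roughly five callbacks fire in the meantime
timer.stop()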
+ } +} + diff --git a/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala b/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala new file mode 100644 index 0000000000..9925b1d07c --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala @@ -0,0 +1,64 @@ +package spark.streaming.util + +import java.net.{Socket, ServerSocket} +import java.io.{ByteArrayOutputStream, DataOutputStream, DataInputStream, BufferedInputStream} + +object Receiver { + def main(args: Array[String]) { + val port = args(0).toInt + val lsocket = new ServerSocket(port) + println("Listening on port " + port ) + while(true) { + val socket = lsocket.accept() + (new Thread() { + override def run() { + val buffer = new Array[Byte](100000) + var count = 0 + val time = System.currentTimeMillis + try { + val is = new DataInputStream(new BufferedInputStream(socket.getInputStream)) + var loop = true + var string: String = null + while((string = is.readUTF) != null) { + count += 28 + } + } catch { + case e: Exception => e.printStackTrace + } + val timeTaken = System.currentTimeMillis - time + val tput = (count / 1024.0) / (timeTaken / 1000.0) + println("Data = " + count + " bytes\nTime = " + timeTaken + " ms\nTput = " + tput + " KB/s") + } + }).start() + } + } + +} + +object Sender { + + def main(args: Array[String]) { + try { + val host = args(0) + val port = args(1).toInt + val size = args(2).toInt + + val byteStream = new ByteArrayOutputStream() + val stringDataStream = new DataOutputStream(byteStream) + (0 until size).foreach(_ => stringDataStream.writeUTF("abcdedfghijklmnopqrstuvwxy")) + val bytes = byteStream.toByteArray() + println("Generated array of " + bytes.length + " bytes") + + /*val bytes = new Array[Byte](size)*/ + val socket = new Socket(host, port) + val os = socket.getOutputStream + os.write(bytes) + os.flush + socket.close() + + } catch { + case e: Exception => e.printStackTrace + } + } +} + diff --git a/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala b/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala new file mode 100644 index 0000000000..94e8f7a849 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala @@ -0,0 +1,92 @@ +package spark.streaming.util + +import spark._ + +import scala.collection.mutable.ArrayBuffer +import scala.util.Random +import scala.io.Source + +import java.net.InetSocketAddress + +import org.apache.hadoop.fs._ +import org.apache.hadoop.conf._ +import org.apache.hadoop.io._ +import org.apache.hadoop.mapred._ +import org.apache.hadoop.util._ + +object SentenceFileGenerator { + + def printUsage () { + println ("Usage: SentenceFileGenerator <# partitions> []") + System.exit(0) + } + + def main (args: Array[String]) { + if (args.length < 4) { + printUsage + } + + val master = args(0) + val fs = new Path(args(1)).getFileSystem(new Configuration()) + val targetDirectory = new Path(args(1)).makeQualified(fs) + val numPartitions = args(2).toInt + val sentenceFile = args(3) + val sentencesPerSecond = { + if (args.length > 4) args(4).toInt + else 10 + } + + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n").toArray + source.close () + println("Read " + lines.length + " lines from file " + sentenceFile) + + val sentences = { + val buffer = ArrayBuffer[String]() + val random = new Random() + var i = 0 + while (i < sentencesPerSecond) { + buffer += lines(random.nextInt(lines.length)) + i += 1 + } + buffer.toArray + 
} + println("Generated " + sentences.length + " sentences") + + val sc = new SparkContext(master, "SentenceFileGenerator") + val sentencesRDD = sc.parallelize(sentences, numPartitions) + + val tempDirectory = new Path(targetDirectory, "_tmp") + + fs.mkdirs(targetDirectory) + fs.mkdirs(tempDirectory) + + var saveTimeMillis = System.currentTimeMillis + try { + while (true) { + val newDir = new Path(targetDirectory, "Sentences-" + saveTimeMillis) + val tmpNewDir = new Path(tempDirectory, "Sentences-" + saveTimeMillis) + println("Writing to file " + newDir) + sentencesRDD.saveAsTextFile(tmpNewDir.toString) + fs.rename(tmpNewDir, newDir) + saveTimeMillis += 1000 + val sleepTimeMillis = { + val currentTimeMillis = System.currentTimeMillis + if (saveTimeMillis < currentTimeMillis) { + 0 + } else { + saveTimeMillis - currentTimeMillis + } + } + println("Sleeping for " + sleepTimeMillis + " ms") + Thread.sleep(sleepTimeMillis) + } + } catch { + case e: Exception => + } + } +} + + + + diff --git a/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala b/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala new file mode 100644 index 0000000000..60085f4f88 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala @@ -0,0 +1,23 @@ +package spark.streaming.util + +import spark.SparkContext +import SparkContext._ + +object ShuffleTest { + def main(args: Array[String]) { + + if (args.length < 1) { + println ("Usage: ShuffleTest ") + System.exit(1) + } + + val sc = new spark.SparkContext(args(0), "ShuffleTest") + val rdd = sc.parallelize(1 to 1000, 500).cache + + def time(f: => Unit) { val start = System.nanoTime; f; println((System.nanoTime - start) * 1.0e-6) } + + time { for (i <- 0 until 50) time { rdd.map(x => (x % 100, x)).reduceByKey(_ + _, 10).count } } + System.exit(0) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/util/Utils.scala b/streaming/src/main/scala/spark/streaming/util/Utils.scala new file mode 100644 index 0000000000..86a729fb49 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/Utils.scala @@ -0,0 +1,9 @@ +package spark.streaming.util + +object Utils { + def time(func: => Unit): Long = { + val t = System.currentTimeMillis + func + (System.currentTimeMillis - t) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala b/streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala deleted file mode 100644 index bb32089ae2..0000000000 --- a/streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala +++ /dev/null @@ -1,78 +0,0 @@ -package spark.streaming - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.net.InetSocketAddress - -/*import akka.actor.Actor._*/ -/*import akka.actor.ActorRef*/ - - -object SenGeneratorForPerformanceTest { - - def printUsage () { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - - def main (args: Array[String]) { - if (args.length < 3) { - printUsage - } - - val inputManagerIP = args(0) - val inputManagerPort = args(1).toInt - val sentenceFile = args(2) - val sentencesPerSecond = { - if (args.length > 3) args(3).toInt - else 10 - } - - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close () - - try { - /*val inputManager = remote.actorFor("InputReceiver-Sentences",*/ 
- /* inputManagerIP, inputManagerPort)*/ - val inputManager = select(Node(inputManagerIP, inputManagerPort), Symbol("InputReceiver-Sentences")) - val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 - val random = new Random () - println ("Sending " + sentencesPerSecond + " sentences per second to " + inputManagerIP + ":" + inputManagerPort) - var lastPrintTime = System.currentTimeMillis() - var count = 0 - - while (true) { - /*if (!inputManager.tryTell (lines (random.nextInt (lines.length))))*/ - /*throw new Exception ("disconnected")*/ -// inputManager ! lines (random.nextInt (lines.length)) - for (i <- 0 to sentencesPerSecond) inputManager ! lines (0) - println(System.currentTimeMillis / 1000 + " s") -/* count += 1 - - if (System.currentTimeMillis - lastPrintTime >= 1000) { - println (count + " sentences sent last second") - count = 0 - lastPrintTime = System.currentTimeMillis - } - - Thread.sleep (sleepBetweenSentences.toLong) -*/ - val currentMs = System.currentTimeMillis / 1000; - Thread.sleep ((currentMs * 1000 + 1000) - System.currentTimeMillis) - } - } catch { - case e: Exception => - /*Thread.sleep (1000)*/ - } - } -} - - - - diff --git a/streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala b/streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala deleted file mode 100644 index 6af270298a..0000000000 --- a/streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala +++ /dev/null @@ -1,63 +0,0 @@ -package spark.streaming -import java.net.{Socket, ServerSocket} -import java.io.{ByteArrayOutputStream, DataOutputStream, DataInputStream, BufferedInputStream} - -object Receiver { - def main(args: Array[String]) { - val port = args(0).toInt - val lsocket = new ServerSocket(port) - println("Listening on port " + port ) - while(true) { - val socket = lsocket.accept() - (new Thread() { - override def run() { - val buffer = new Array[Byte](100000) - var count = 0 - val time = System.currentTimeMillis - try { - val is = new DataInputStream(new BufferedInputStream(socket.getInputStream)) - var loop = true - var string: String = null - while((string = is.readUTF) != null) { - count += 28 - } - } catch { - case e: Exception => e.printStackTrace - } - val timeTaken = System.currentTimeMillis - time - val tput = (count / 1024.0) / (timeTaken / 1000.0) - println("Data = " + count + " bytes\nTime = " + timeTaken + " ms\nTput = " + tput + " KB/s") - } - }).start() - } - } - -} - -object Sender { - - def main(args: Array[String]) { - try { - val host = args(0) - val port = args(1).toInt - val size = args(2).toInt - - val byteStream = new ByteArrayOutputStream() - val stringDataStream = new DataOutputStream(byteStream) - (0 until size).foreach(_ => stringDataStream.writeUTF("abcdedfghijklmnopqrstuvwxy")) - val bytes = byteStream.toByteArray() - println("Generated array of " + bytes.length + " bytes") - - /*val bytes = new Array[Byte](size)*/ - val socket = new Socket(host, port) - val os = socket.getOutputStream - os.write(bytes) - os.flush - socket.close() - - } catch { - case e: Exception => e.printStackTrace - } - } -} - diff --git a/streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala b/streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala deleted file mode 100644 index 15858f59e3..0000000000 --- a/streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala +++ /dev/null @@ -1,92 +0,0 @@ -package spark.streaming - -import spark._ - -import scala.collection.mutable.ArrayBuffer 
-import scala.util.Random -import scala.io.Source - -import java.net.InetSocketAddress - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -object SentenceFileGenerator { - - def printUsage () { - println ("Usage: SentenceFileGenerator <# partitions> []") - System.exit(0) - } - - def main (args: Array[String]) { - if (args.length < 4) { - printUsage - } - - val master = args(0) - val fs = new Path(args(1)).getFileSystem(new Configuration()) - val targetDirectory = new Path(args(1)).makeQualified(fs) - val numPartitions = args(2).toInt - val sentenceFile = args(3) - val sentencesPerSecond = { - if (args.length > 4) args(4).toInt - else 10 - } - - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n").toArray - source.close () - println("Read " + lines.length + " lines from file " + sentenceFile) - - val sentences = { - val buffer = ArrayBuffer[String]() - val random = new Random() - var i = 0 - while (i < sentencesPerSecond) { - buffer += lines(random.nextInt(lines.length)) - i += 1 - } - buffer.toArray - } - println("Generated " + sentences.length + " sentences") - - val sc = new SparkContext(master, "SentenceFileGenerator") - val sentencesRDD = sc.parallelize(sentences, numPartitions) - - val tempDirectory = new Path(targetDirectory, "_tmp") - - fs.mkdirs(targetDirectory) - fs.mkdirs(tempDirectory) - - var saveTimeMillis = System.currentTimeMillis - try { - while (true) { - val newDir = new Path(targetDirectory, "Sentences-" + saveTimeMillis) - val tmpNewDir = new Path(tempDirectory, "Sentences-" + saveTimeMillis) - println("Writing to file " + newDir) - sentencesRDD.saveAsTextFile(tmpNewDir.toString) - fs.rename(tmpNewDir, newDir) - saveTimeMillis += 1000 - val sleepTimeMillis = { - val currentTimeMillis = System.currentTimeMillis - if (saveTimeMillis < currentTimeMillis) { - 0 - } else { - saveTimeMillis - currentTimeMillis - } - } - println("Sleeping for " + sleepTimeMillis + " ms") - Thread.sleep(sleepTimeMillis) - } - } catch { - case e: Exception => - } - } -} - - - - diff --git a/streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala b/streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala deleted file mode 100644 index a9f124d2d7..0000000000 --- a/streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala +++ /dev/null @@ -1,103 +0,0 @@ -package spark.streaming - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.net.InetSocketAddress - - -object SentenceGenerator { - - def printUsage { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - - def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 - val random = new Random () - - try { - var lastPrintTime = System.currentTimeMillis() - var count = 0 - while(true) { - streamReceiver ! 
lines(random.nextInt(lines.length)) - count += 1 - if (System.currentTimeMillis - lastPrintTime >= 1000) { - println (count + " sentences sent last second") - count = 0 - lastPrintTime = System.currentTimeMillis - } - Thread.sleep(sleepBetweenSentences.toLong) - } - } catch { - case e: Exception => - } - } - - def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - try { - val numSentences = if (sentencesPerSecond <= 0) { - lines.length - } else { - sentencesPerSecond - } - var nextSendingTime = System.currentTimeMillis() - val pingInterval = if (System.getenv("INTERVAL") != null) { - System.getenv("INTERVAL").toInt - } else { - 2000 - } - while(true) { - (0 until numSentences).foreach(i => { - streamReceiver ! lines(i % lines.length) - }) - println ("Sent " + numSentences + " sentences") - nextSendingTime += pingInterval - val sleepTime = nextSendingTime - System.currentTimeMillis - if (sleepTime > 0) { - println ("Sleeping for " + sleepTime + " ms") - Thread.sleep(sleepTime) - } - } - } catch { - case e: Exception => - } - } - - def main(args: Array[String]) { - if (args.length < 3) { - printUsage - } - - val generateRandomly = false - - val streamReceiverIP = args(0) - val streamReceiverPort = args(1).toInt - val sentenceFile = args(2) - val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10 - val sentenceInputName = if (args.length > 4) args(4) else "Sentences" - - println("Sending " + sentencesPerSecond + " sentences per second to " + - streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName) - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close () - - val streamReceiver = select( - Node(streamReceiverIP, streamReceiverPort), - Symbol("NetworkStreamReceiver-" + sentenceInputName)) - if (generateRandomly) { - generateRandomSentences(lines, sentencesPerSecond, streamReceiver) - } else { - generateSameSentences(lines, sentencesPerSecond, streamReceiver) - } - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala b/streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala deleted file mode 100644 index 32aa4144a0..0000000000 --- a/streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala +++ /dev/null @@ -1,22 +0,0 @@ -package spark.streaming -import spark.SparkContext -import SparkContext._ - -object ShuffleTest { - def main(args: Array[String]) { - - if (args.length < 1) { - println ("Usage: ShuffleTest ") - System.exit(1) - } - - val sc = new spark.SparkContext(args(0), "ShuffleTest") - val rdd = sc.parallelize(1 to 1000, 500).cache - - def time(f: => Unit) { val start = System.nanoTime; f; println((System.nanoTime - start) * 1.0e-6) } - - time { for (i <- 0 until 50) time { rdd.map(x => (x % 100, x)).reduceByKey(_ + _, 10).count } } - System.exit(0) - } -} - diff --git a/streaming/src/test/scala/spark/streaming/RDSSuite.scala b/streaming/src/test/scala/spark/streaming/RDSSuite.scala new file mode 100644 index 0000000000..f51ea50a5d --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/RDSSuite.scala @@ -0,0 +1,65 @@ +package spark.streaming + +import spark.RDD + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.SynchronizedQueue + +class RDSSuite extends FunSuite with BeforeAndAfter { + + var ssc: SparkStreamContext = null + val batchDurationMillis = 1000 + + def testOp[U: 
ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: RDS[U] => RDS[V], + expectedOutput: Seq[Seq[V]]) = { + try { + ssc = new SparkStreamContext("local", "test") + ssc.setBatchDuration(Milliseconds(batchDurationMillis)) + + val inputStream = ssc.createQueueStream(input.map(ssc.sc.makeRDD(_, 2)).toIterator) + val outputStream = operation(inputStream) + val outputQueue = outputStream.toQueue + + ssc.start() + Thread.sleep(batchDurationMillis * input.size) + + val output = new ArrayBuffer[Seq[V]]() + while(outputQueue.size > 0) { + val rdd = outputQueue.take() + println("Collecting RDD " + rdd.id + ", " + rdd.getClass().getSimpleName() + ", " + rdd.splits.size) + output += (rdd.collect()) + } + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i).toList === expectedOutput(i).toList) + } + } finally { + ssc.stop() + } + } + + test("basic operations") { + val inputData = Array(1 to 4, 5 to 8, 9 to 12) + + // map + testOp(inputData, (r: RDS[Int]) => r.map(_.toString), inputData.map(_.map(_.toString))) + + // flatMap + testOp(inputData, (r: RDS[Int]) => r.flatMap(x => Array(x, x * 2)), + inputData.map(_.flatMap(x => Array(x, x * 2))) + ) + } +} + +object RDSSuite { + def main(args: Array[String]) { + val r = new RDSSuite() + val inputData = Array(1 to 4, 5 to 8, 9 to 12) + r.testOp(inputData, (r: RDS[Int]) => r.map(_.toString), inputData.map(_.map(_.toString))) + } +} \ No newline at end of file -- cgit v1.2.3 From 741899b21e4e6439459fcf4966076661c851ed07 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 16:26:06 -0700 Subject: Fix sendMessageReliablySync --- core/src/main/scala/spark/network/ConnectionManager.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 1a22d06cc8..66b822117f 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -14,7 +14,8 @@ import scala.collection.mutable.SynchronizedQueue import scala.collection.mutable.Queue import scala.collection.mutable.ArrayBuffer -import akka.dispatch.{Promise, ExecutionContext, Future} +import akka.dispatch.{Await, Promise, ExecutionContext, Future} +import akka.util.Duration case class ConnectionManagerId(host: String, port: Int) { def toSocketAddress() = new InetSocketAddress(host, port) @@ -325,7 +326,7 @@ class ConnectionManager(port: Int) extends Logging { } def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = { - sendMessageReliably(connectionManagerId, message)() + Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf) } def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { -- cgit v1.2.3 From 06ef7c3d1bf8446d4d6ef8f3a055dd1e6d32ca3a Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 16:29:20 -0700 Subject: Less debug info --- core/src/main/scala/spark/network/ConnectionManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 66b822117f..bd0980029a 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -248,7 +248,7 @@ class 
ConnectionManager(port: Int) extends Logging { } private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) { - logInfo("Handling [" + message + "] from [" + connectionManagerId + "]") + logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") message match { case bufferMessage: BufferMessage => { if (bufferMessage.hasAckId) { -- cgit v1.2.3 From 29e83f39e90b4d3cbeeb40d5ec0c19bd003c1840 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 18:16:25 -0700 Subject: Fix replication with MEMORY_ONLY_DESER_2 --- core/src/main/scala/spark/storage/BlockManager.scala | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index ff9914ae25..45f99717bc 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -364,6 +364,12 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val startTimeMs = System.currentTimeMillis var bytes: ByteBuffer = null + + // If we need to replicate the data, we'll want access to the values, but because our + // put will read the whole iterator, there will be no values left. For the case where + // the put serializes data, we'll remember the bytes, above; but for the case where + // it doesn't, such as MEMORY_ONLY_DESER, let's rely on the put returning an Iterator. + var valuesAfterPut: Iterator[Any] = null locker.getLock(blockId).synchronized { logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) @@ -391,7 +397,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // If only save to memory memoryStore.putValues(blockId, values, level) match { case Right(newBytes) => bytes = newBytes - case _ => + case Left(newIterator) => valuesAfterPut = newIterator } } else { // If only save to disk @@ -408,8 +414,13 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // Replicate block if required if (level.replication > 1) { + // Serialize the block if not already done if (bytes == null) { - bytes = dataSerialize(values) // serialize the block if not already done + if (valuesAfterPut == null) { + throw new SparkException( + "Underlying put returned neither an Iterator nor bytes! 
This shouldn't happen.") + } + bytes = dataSerialize(valuesAfterPut) } replicate(blockId, bytes, level) } -- cgit v1.2.3 From 26dfd20c9a5139bd682a9902267b9d54a11ae20f Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 18:56:56 -0700 Subject: Detect disconnected slaves in StandaloneScheduler --- .../cluster/StandaloneSchedulerBackend.scala | 38 ++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 013671c1c8..83e7c6e036 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -2,13 +2,14 @@ package spark.scheduler.cluster import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import akka.actor.{Props, Actor, ActorRef, ActorSystem} +import akka.actor._ import akka.util.duration._ import akka.pattern.ask import spark.{SparkException, Logging, TaskState} import akka.dispatch.Await import java.util.concurrent.atomic.AtomicInteger +import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} /** * A standalone scheduler backend, which waits for standalone executors to connect to it through @@ -23,8 +24,16 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor { val slaveActor = new HashMap[String, ActorRef] + val slaveAddress = new HashMap[String, Address] val slaveHost = new HashMap[String, String] val freeCores = new HashMap[String, Int] + val actorToSlaveId = new HashMap[ActorRef, String] + val addressToSlaveId = new HashMap[Address, String] + + override def preStart() { + // Listen for remote client disconnection events, since they don't go through Akka's watch() + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + } def receive = { case RegisterSlave(slaveId, host, cores) => @@ -33,9 +42,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } else { logInfo("Registered slave: " + sender + " with ID " + slaveId) sender ! RegisteredSlave(sparkProperties) + context.watch(sender) slaveActor(slaveId) = sender slaveHost(slaveId) = host freeCores(slaveId) = cores + slaveAddress(slaveId) = sender.path.address + actorToSlaveId(sender) = slaveId + addressToSlaveId(sender.path.address) = slaveId totalCoreCount.addAndGet(cores) makeOffers() } @@ -54,7 +67,14 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor sender ! true context.stop(self) - // TODO: Deal with nodes disconnecting too! (Including decreasing totalCoreCount) + case Terminated(actor) => + actorToSlaveId.get(actor).foreach(removeSlave) + + case RemoteClientDisconnected(transport, address) => + addressToSlaveId.get(address).foreach(removeSlave) + + case RemoteClientShutdown(transport, address) => + addressToSlaveId.get(address).foreach(removeSlave) } // Make fake resource offers on all slaves @@ -76,6 +96,20 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor slaveActor(task.slaveId) ! 
LaunchTask(task) } } + + // Remove a disconnected slave from the cluster + def removeSlave(slaveId: String) { + logInfo("Slave " + slaveId + " disconnected, so removing it") + val numCores = freeCores(slaveId) + actorToSlaveId -= slaveActor(slaveId) + addressToSlaveId -= slaveAddress(slaveId) + slaveActor -= slaveId + slaveHost -= slaveId + freeCores -= slaveId + slaveHost -= slaveId + totalCoreCount.addAndGet(-numCores) + scheduler.slaveLost(slaveId) + } } var masterActor: ActorRef = null -- cgit v1.2.3 From 3c9c44a8d36c0f2dff40a50f1f1e3bc3dac7be7e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 19:37:43 -0700 Subject: More helpful log messages --- core/src/main/scala/spark/MapOutputTracker.scala | 3 ++- core/src/main/scala/spark/network/Connection.scala | 4 ++-- core/src/main/scala/spark/network/ConnectionManager.scala | 2 +- core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 1 + 4 files changed, 6 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 0c97cd44a1..e249430905 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -116,7 +116,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = { val locs = bmAddresses.get(shuffleId) if (locs == null) { - logInfo("Don't have map outputs for shuffe " + shuffleId + ", fetching them") + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") fetching.synchronized { if (fetching.contains(shuffleId)) { // Someone else is fetching it; wait for them to be done @@ -158,6 +158,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg def incrementGeneration() { generationLock.synchronized { generation += 1 + logInfo("Increasing generation to " + generation) } } diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 451faee66e..da8aff9dd5 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -111,7 +111,7 @@ extends Connection(SocketChannel.open, selector_) { messages.synchronized{ /*messages += message*/ messages.enqueue(message) - logInfo("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") + logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") } } @@ -136,7 +136,7 @@ extends Connection(SocketChannel.open, selector_) { return chunk } /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ - logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) + logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) } } None diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index bd0980029a..0e764fff81 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -306,7 +306,7 @@ class ConnectionManager(port: Int) extends Logging { } val connection = connectionsById.getOrElse(connectionManagerId, startNewConnection()) message.senderAddress = id.toSocketAddress() - logInfo("Sending [" + message + "] 
to [" + connectionManagerId + "]") + logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") /*connection.send(message)*/ sendMessageRequests.synchronized { sendMessageRequests += ((message, connection)) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index be24316e80..5412e8d8c0 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -88,6 +88,7 @@ class TaskSetManager( // Figure out the current map output tracker generation and set it on all tasks val generation = sched.mapOutputTracker.getGeneration + logInfo("Generation for " + taskSet.id + ": " + generation) for (t <- tasks) { t.generation = generation } -- cgit v1.2.3 From 117e3f8c8602c1303fa0e31840d85d1a7a6e3d9d Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 19:52:56 -0700 Subject: Fix a bug that was causing FetchFailedException not to be thrown --- core/src/main/scala/spark/BlockStoreShuffleFetcher.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index 3431ad2258..45a14c8290 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -48,8 +48,9 @@ class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { } } } catch { + // TODO: this is really ugly -- let's find a better way of throwing a FetchFailedException case be: BlockException => { - val regex = "shuffledid_([0-9]*)_([0-9]*)_([0-9]]*)".r + val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]]*)".r be.blockId match { case regex(sId, mId, rId) => { val address = addresses(mId.toInt) -- cgit v1.2.3 From 69c2ab04083972e4ecf1393ffd0cb0acb56b4f7d Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 20:00:58 -0700 Subject: logging --- core/src/main/scala/spark/executor/Executor.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 9e335c25f7..dba209ac27 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -63,6 +63,7 @@ class Executor extends Logging { Thread.currentThread.setContextClassLoader(classLoader) Accumulators.clear() val task = ser.deserialize[Task[Any]](serializedTask, classLoader) + logInfo("Its generation is " + task.generation) env.mapOutputTracker.updateGeneration(task.generation) val value = task.run(taskId.toInt) val accumUpdates = Accumulators.values -- cgit v1.2.3 From b914cd0dfa21b615c29d2ce935f623f209afa8f4 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 20:07:59 -0700 Subject: Serialize generation correctly in ShuffleMapTask --- core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 2 ++ 1 file changed, 2 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index f78e0e5fb2..73479bff01 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -90,6 +90,7 @@ class ShuffleMapTask( out.writeInt(bytes.length) out.write(bytes) out.writeInt(partition) + out.writeLong(generation) 
out.writeObject(split) } @@ -102,6 +103,7 @@ class ShuffleMapTask( rdd = rdd_ dep = dep_ partition = in.readInt() + generation = in.readLong() split = in.readObject().asInstanceOf[Split] } -- cgit v1.2.3 From b4a2214218eeb9ebd95b59d88c2212fe717efd9e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 27 Aug 2012 22:49:29 -0700 Subject: More fault tolerance fixes to catch lost tasks --- core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala | 2 ++ .../main/scala/spark/scheduler/cluster/TaskSetManager.scala | 13 +++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala index 0fc1d8ed30..65e59841a9 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala @@ -20,6 +20,8 @@ class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host: def successful: Boolean = finished && !failed + def running: Boolean = !finished + def duration: Long = { if (!finished) { throw new UnsupportedOperationException("duration() called on unfinished tasks") diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 5412e8d8c0..17317e80df 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -265,6 +265,11 @@ class TaskSetManager( def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. 
+ return + } val index = info.index info.markFailed() if (!finished(index)) { @@ -341,7 +346,7 @@ class TaskSetManager( } def hostLost(hostname: String) { - logInfo("Re-queueing tasks for " + hostname) + logInfo("Re-queueing tasks for " + hostname + " from TaskSet " + taskSet.id) // If some task has preferred locations only on hostname, put it in the no-prefs list // to avoid the wait from delay scheduling for (index <- getPendingTasksForHost(hostname)) { @@ -350,7 +355,7 @@ class TaskSetManager( pendingTasksWithNoPrefs += index } } - // Also re-enqueue any tasks that ran on the failed host if this is a shuffle map stage + // Re-enqueue any tasks that ran on the failed host if this is a shuffle map stage if (tasks(0).isInstanceOf[ShuffleMapTask]) { for ((tid, info) <- taskInfos if info.host == hostname) { val index = taskInfos(tid).index @@ -365,6 +370,10 @@ class TaskSetManager( } } } + // Also re-enqueue any tasks that were running on the node + for ((tid, info) <- taskInfos if info.running && info.host == hostname) { + taskLost(tid, TaskState.KILLED, null) + } } /** -- cgit v1.2.3 From 17af2df0cdcdc4f02013bd7b4351e0a9d9ee9b25 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 27 Aug 2012 23:07:32 -0700 Subject: Log levels --- core/src/main/scala/spark/MapOutputTracker.scala | 2 +- core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index e249430905..de23eb6f48 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -158,7 +158,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg def incrementGeneration() { generationLock.synchronized { generation += 1 - logInfo("Increasing generation to " + generation) + logDebug("Increasing generation to " + generation) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 17317e80df..5a7df6040c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -88,7 +88,7 @@ class TaskSetManager( // Figure out the current map output tracker generation and set it on all tasks val generation = sched.mapOutputTracker.getGeneration - logInfo("Generation for " + taskSet.id + ": " + generation) + logDebug("Generation for " + taskSet.id + ": " + generation) for (t <- tasks) { t.generation = generation } -- cgit v1.2.3 From 4db3a967669a53de4c4b79b4c0b70daa5accb682 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 29 Aug 2012 13:04:01 -0700 Subject: Made minor changes to reduce compilation errors in Eclipse. Twirl stuff still does not compile in Eclipse. 
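For reference, the Akka 2 idiom introduced below blocks on a Future through Await.result with an explicit timeout, instead of invoking the future's apply() method as the old code did. A minimal standalone sketch of that pattern, assuming Akka 2.0 on the classpath (the cached thread pool, the computed value and the 1-second timeout are illustrative, not taken from this patch):

    import java.util.concurrent.Executors
    import akka.dispatch.{Await, ExecutionContext, Future}
    import akka.util.duration._

    object AwaitExample {
      def main(args: Array[String]) {
        // Futures need an implicit ExecutionContext in Akka 2; build one from a plain thread pool
        implicit val ec = ExecutionContext.fromExecutor(Executors.newCachedThreadPool())
        val f: Future[Int] = Future { 21 * 2 }
        // Await.result blocks the calling thread until the future completes or the timeout expires
        val answer = Await.result(f, 1 second)
        println("Got " + answer)
        System.exit(0)
      }
    }

The same Await.result call, with Duration.Inf, is what the earlier sendMessageReliablySync fix in this series uses to preserve the old blocking behaviour.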
--- .../src/main/scala/spark/network/ConnectionManager.scala | 16 +++++++++++++--- .../main/scala/spark/network/ConnectionManagerTest.scala | 5 ++++- 2 files changed, 17 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 0e764fff81..2bb5f5fc6b 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -16,6 +16,7 @@ import scala.collection.mutable.ArrayBuffer import akka.dispatch.{Await, Promise, ExecutionContext, Future} import akka.util.Duration +import akka.util.duration._ case class ConnectionManagerId(host: String, port: Int) { def toSocketAddress() = new InetSocketAddress(host, port) @@ -403,7 +404,10 @@ object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => {if (!f().isDefined) println("Failed")}) + }).foreach(f => { + val g = Await.result(f, 1 second) + if (!g.isDefined) println("Failed") + }) val finishTime = System.currentTimeMillis val mb = size * count / 1024.0 / 1024.0 @@ -430,7 +434,10 @@ object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => {if (!f().isDefined) println("Failed")}) + }).foreach(f => { + val g = Await.result(f, 1 second) + if (!g.isDefined) println("Failed") + }) val finishTime = System.currentTimeMillis val ms = finishTime - startTime @@ -457,7 +464,10 @@ object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => {if (!f().isDefined) println("Failed")}) + }).foreach(f => { + val g = Await.result(f, 1 second) + if (!g.isDefined) println("Failed") + }) val finishTime = System.currentTimeMillis Thread.sleep(1000) val mb = size * count / 1024.0 / 1024.0 diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala index 5d21bb793f..555b3454ee 100644 --- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala @@ -8,6 +8,9 @@ import scala.io.Source import java.nio.ByteBuffer import java.net.InetAddress +import akka.dispatch.Await +import akka.util.duration._ + object ConnectionManagerTest extends Logging{ def main(args: Array[String]) { if (args.length < 2) { @@ -53,7 +56,7 @@ object ConnectionManagerTest extends Logging{ logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) }) - val results = futures.map(f => f()) + val results = futures.map(f => Await.result(f, 1.second)) val finishTime = System.currentTimeMillis Thread.sleep(5000) -- cgit v1.2.3 From c4366eb76425d1c6aeaa7df750a2681a0da75db8 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 31 Aug 2012 00:34:24 +0000 Subject: Fixes to ShuffleFetcher --- .../scala/spark/BlockStoreShuffleFetcher.scala | 41 +++++++++------------- 1 file changed, 17 insertions(+), 24 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index 45a14c8290..0bbdb4e432 
100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -32,36 +32,29 @@ class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { (address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId))) } - try { - for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) { - blockOption match { - case Some(block) => { - val values = block - for(value <- values) { - val v = value.asInstanceOf[(K, V)] - func(v._1, v._2) - } - } - case None => { - throw new BlockException(blockId, "Did not get block " + blockId) + for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) { + blockOption match { + case Some(block) => { + val values = block + for(value <- values) { + val v = value.asInstanceOf[(K, V)] + func(v._1, v._2) } } - } - } catch { - // TODO: this is really ugly -- let's find a better way of throwing a FetchFailedException - case be: BlockException => { - val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]]*)".r - be.blockId match { - case regex(sId, mId, rId) => { - val address = addresses(mId.toInt) - throw new FetchFailedException(address, sId.toInt, mId.toInt, rId.toInt, be) - } - case _ => { - throw be + case None => { + val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]*)".r + blockId match { + case regex(shufId, mapId, reduceId) => + val addr = addresses(mapId.toInt) + throw new FetchFailedException(addr, shufId.toInt, mapId.toInt, reduceId.toInt, null) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block") } } } } + logDebug("Fetching and merging outputs of shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) } -- cgit v1.2.3 From 1b3e3352ebfed40881d534cd3096d4b6428c24d4 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 30 Aug 2012 17:59:25 -0700 Subject: Deserialize multi-get results in the caller's thread. This fixes an issue with shared buffers with the KryoSerializer. --- core/src/main/scala/spark/storage/BlockManager.scala | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 45f99717bc..e9197f7169 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -272,11 +272,15 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val totalBlocks = blocksByAddress.map(_._2.size).sum logDebug("Getting " + totalBlocks + " blocks") var startTime = System.currentTimeMillis - val results = new LinkedBlockingQueue[(String, Option[Iterator[Any]])] val localBlockIds = new ArrayBuffer[String]() val remoteBlockIds = new ArrayBuffer[String]() val remoteBlockIdsPerLocation = new HashMap[BlockManagerId, Seq[String]]() + // A queue to hold our results. Because we want all the deserializing the happen in the + // caller's thread, this will actually hold functions to produce the Iterator for each block. + // For local blocks we'll have an iterator already, while for remote ones we'll deserialize. 
+ val results = new LinkedBlockingQueue[(String, Option[() => Iterator[Any]])] + // Split local and remote blocks for ((address, blockIds) <- blocksByAddress) { if (address == blockManagerId) { @@ -302,10 +306,8 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new SparkException( "Unexpected message " + blockMessage.getType + " received from " + cmId) } - val buffer = blockMessage.getData val blockId = blockMessage.getId - val block = dataDeserialize(buffer) - results.put((blockId, Some(block))) + results.put((blockId, Some(() => dataDeserialize(blockMessage.getData)))) logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) }) } @@ -323,9 +325,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // Get the local blocks while remote blocks are being fetched startTime = System.currentTimeMillis localBlockIds.foreach(id => { - get(id) match { + getLocal(id) match { case Some(block) => { - results.put((id, Some(block))) + results.put((id, Some(() => block))) logDebug("Got local block " + id) } case None => { @@ -343,7 +345,8 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 - results.take() + val (blockId, functionOption) = results.take() + (blockId, functionOption.map(_.apply())) } } } -- cgit v1.2.3 From 101ae493e26693146114ac01d50d411f5b2e0762 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 30 Aug 2012 22:24:14 -0700 Subject: Replicate serialized blocks properly, without sharing a ByteBuffer. --- core/src/main/scala/spark/storage/BlockManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index e9197f7169..8a013230da 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -456,7 +456,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // data is already serialized and ready for sending val replicationFuture = if (level.replication > 1) { Future { - replicate(blockId, bytes, level) + replicate(blockId, bytes.duplicate(), level) } } else { null -- cgit v1.2.3 From 113277549c5ee1bcd58c7cebc365d28d92b74b4a Mon Sep 17 00:00:00 2001 From: root Date: Fri, 31 Aug 2012 05:39:35 +0000 Subject: Really fixed the replication-3 issue. The problem was a few buffers not being rewound. --- core/src/main/scala/spark/storage/BlockManager.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 8a013230da..f2d9499bad 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -455,8 +455,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // Initiate the replication before storing it locally. 
This is faster as // data is already serialized and ready for sending val replicationFuture = if (level.replication > 1) { + val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper Future { - replicate(blockId, bytes.duplicate(), level) + replicate(blockId, bufferView, level) } } else { null @@ -514,15 +515,16 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m var peers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) for (peer: BlockManagerId <- peers) { val start = System.nanoTime + data.rewind() logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " - + data.array().length + " Bytes. To node: " + peer) + + data.limit() + " Bytes. To node: " + peer) if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel), new ConnectionManagerId(peer.ip, peer.port))) { logError("Failed to call syncPutBlock to " + peer) } logDebug("Replicated BlockId " + blockId + " once used " + (System.nanoTime - start) / 1e6 + " s; The size of the data is " + - data.array().length + " bytes.") + data.limit() + " bytes.") } } -- cgit v1.2.3 From c42e7ac2822f697a355650a70379d9e4ce2022c0 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 1 Sep 2012 04:31:11 +0000 Subject: More block manager fixes --- .../scala/spark/storage/BlockManagerWorker.scala | 2 +- core/src/main/scala/spark/storage/BlockStore.scala | 30 ++++++++++------------ 2 files changed, 15 insertions(+), 17 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala index d74cdb38a8..0658a57187 100644 --- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala @@ -73,7 +73,7 @@ class BlockManagerWorker(val blockManager: BlockManager) extends Logging { logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) blockManager.putBytes(id, bytes, level) logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " with data size: " + bytes.array().length) + + " with data size: " + bytes.limit) } private def getBlock(id: String): ByteBuffer = { diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala index 17f4f51aa8..77e0ed84c5 100644 --- a/core/src/main/scala/spark/storage/BlockStore.scala +++ b/core/src/main/scala/spark/storage/BlockStore.scala @@ -76,11 +76,11 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long) currentMemory += sizeEstimate logDebug("Block " + blockId + " stored as values to memory") } else { - val entry = new Entry(bytes, bytes.array().length, false) - ensureFreeSpace(bytes.array.length) + val entry = new Entry(bytes, bytes.limit, false) + ensureFreeSpace(bytes.limit) memoryStore.synchronized { memoryStore.put(blockId, entry) } - currentMemory += bytes.array().length - logDebug("Block " + blockId + " stored as " + bytes.array().length + " bytes to memory") + currentMemory += bytes.limit + logDebug("Block " + blockId + " stored as " + bytes.limit + " bytes to memory") } } @@ -97,11 +97,11 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long) return Left(elements.iterator) } else { val bytes = dataSerialize(values) - ensureFreeSpace(bytes.array().length) - val entry = new Entry(bytes, bytes.array().length, false) + ensureFreeSpace(bytes.limit) + val entry = new Entry(bytes, bytes.limit, false) memoryStore.synchronized { 
memoryStore.put(blockId, entry) } - currentMemory += bytes.array().length - logDebug("Block " + blockId + " stored as " + bytes.array.length + " bytes to memory") + currentMemory += bytes.limit + logDebug("Block " + blockId + " stored as " + bytes.limit + " bytes to memory") return Right(bytes) } } @@ -118,7 +118,7 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long) if (entry.deserialized) { return Some(entry.value.asInstanceOf[ArrayBuffer[Any]].toIterator) } else { - return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer])) + return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer].duplicate())) } } @@ -199,11 +199,11 @@ class DiskStore(blockManager: BlockManager, rootDirs: String) val file = createFile(blockId) if (file != null) { val channel = new RandomAccessFile(file, "rw").getChannel() - val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.array.length) - buffer.put(bytes.array) + val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.limit) + buffer.put(bytes) channel.close() val finishTime = System.currentTimeMillis - logDebug("Block " + blockId + " stored to file of " + bytes.array.length + " bytes to disk in " + (finishTime - startTime) + " ms") + logDebug("Block " + blockId + " stored to file of " + bytes.limit + " bytes to disk in " + (finishTime - startTime) + " ms") } else { logError("File not created for block " + blockId) } @@ -211,7 +211,7 @@ class DiskStore(blockManager: BlockManager, rootDirs: String) def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = { val bytes = dataSerialize(values) - logDebug("Converted block " + blockId + " to " + bytes.array.length + " bytes") + logDebug("Converted block " + blockId + " to " + bytes.limit + " bytes") putBytes(blockId, bytes, level) return Right(bytes) } @@ -220,9 +220,7 @@ class DiskStore(blockManager: BlockManager, rootDirs: String) val file = getFile(blockId) val length = file.length().toInt val channel = new RandomAccessFile(file, "r").getChannel() - val bytes = ByteBuffer.allocate(length) - bytes.put(channel.map(MapMode.READ_WRITE, 0, length)) - return Some(bytes) + Some(channel.map(MapMode.READ_WRITE, 0, length)) } def getValues(blockId: String): Option[Iterator[Any]] = { -- cgit v1.2.3 From 44758aa8e2337364610ee80fa9ec913301712078 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 1 Sep 2012 00:17:59 -0700 Subject: First work towards a RawInputDStream and a sender program for it. 
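The RateLimitedOutputStream added in this change wraps an ordinary OutputStream and sleeps between chunks so that the sustained write rate stays near the requested bytes-per-second limit; the sender program built on top of it is RawTextSender. A rough usage sketch, assuming a receiver is already listening (the host, port, payload and 1 MB/s cap below are made up for illustration):

    import java.net.Socket
    import spark.util.RateLimitedOutputStream

    object ThrottledSender {
      def main(args: Array[String]) {
        val socket = new Socket("localhost", 9999)  // illustrative endpoint
        // Cap writes at roughly 1 MB/s; the stream sleeps whenever the recent rate exceeds this
        val out = new RateLimitedOutputStream(socket.getOutputStream, 1024 * 1024)
        val line = "the quick brown fox jumps over the lazy dog\n".getBytes("UTF-8")
        for (i <- 1 to 100000) {
          out.write(line)
        }
        out.flush()
        socket.close()
      }
    }

The limiter checks the rate per 8 KB chunk against the bytes written since the last sync point, which is reset roughly every ten seconds, so the cap is a smoothed average rather than a hard per-write bound.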
--- .../src/main/scala/spark/DaemonThreadFactory.scala | 12 +- .../scala/spark/util/RateLimitedOutputStream.scala | 56 +++++ .../spark/streaming/NetworkInputDStream.scala | 33 +-- .../spark/streaming/NetworkInputReceiver.scala | 248 --------------------- .../streaming/NetworkInputReceiverMessage.scala | 7 + .../spark/streaming/NetworkInputTracker.scala | 7 +- .../scala/spark/streaming/ObjectInputDStream.scala | 16 ++ .../spark/streaming/ObjectInputReceiver.scala | 244 ++++++++++++++++++++ .../scala/spark/streaming/RawInputDStream.scala | 114 ++++++++++ .../scala/spark/streaming/StreamingContext.scala | 8 +- .../scala/spark/streaming/util/RawTextSender.scala | 51 +++++ 11 files changed, 512 insertions(+), 284 deletions(-) create mode 100644 core/src/main/scala/spark/util/RateLimitedOutputStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala create mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala create mode 100644 streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala create mode 100644 streaming/src/main/scala/spark/streaming/RawInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/RawTextSender.scala (limited to 'core') diff --git a/core/src/main/scala/spark/DaemonThreadFactory.scala b/core/src/main/scala/spark/DaemonThreadFactory.scala index 003880c5e8..56e59adeb7 100644 --- a/core/src/main/scala/spark/DaemonThreadFactory.scala +++ b/core/src/main/scala/spark/DaemonThreadFactory.scala @@ -6,9 +6,13 @@ import java.util.concurrent.ThreadFactory * A ThreadFactory that creates daemon threads */ private object DaemonThreadFactory extends ThreadFactory { - override def newThread(r: Runnable): Thread = { - val t = new Thread(r) - t.setDaemon(true) - return t + override def newThread(r: Runnable): Thread = new DaemonThread(r) +} + +private class DaemonThread(r: Runnable = null) extends Thread { + override def run() { + if (r != null) { + r.run() + } } } \ No newline at end of file diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala new file mode 100644 index 0000000000..10f2272707 --- /dev/null +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -0,0 +1,56 @@ +package spark.util + +import java.io.OutputStream + +class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { + var lastSyncTime = System.nanoTime() + var bytesWrittenSinceSync: Long = 0 + + override def write(b: Int) { + waitToWrite(1) + out.write(b) + } + + override def write(bytes: Array[Byte]) { + write(bytes, 0, bytes.length) + } + + override def write(bytes: Array[Byte], offset: Int, length: Int) { + val CHUNK_SIZE = 8192 + var pos = 0 + while (pos < length) { + val writeSize = math.min(length - pos, CHUNK_SIZE) + waitToWrite(writeSize) + out.write(bytes, offset + pos, length - pos) + pos += writeSize + } + } + + def waitToWrite(numBytes: Int) { + while (true) { + val now = System.nanoTime() + val elapsed = math.max(now - lastSyncTime, 1) + val rate = bytesWrittenSinceSync.toDouble / (elapsed / 1.0e9) + if (rate < bytesPerSec) { + // It's okay to write; just update some variables and return + bytesWrittenSinceSync += numBytes + if (now > lastSyncTime + (1e10).toLong) { + // Ten seconds have passed since lastSyncTime; let's resync + lastSyncTime = now + bytesWrittenSinceSync = numBytes + } + 
return + } else { + Thread.sleep(5) + } + } + } + + override def flush() { + out.flush() + } + + override def close() { + out.close() + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala index ee09324c8c..bf83f98ec4 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -1,36 +1,23 @@ package spark.streaming -import akka.actor._ -import akka.pattern.ask -import akka.util.duration._ -import akka.dispatch._ - import spark.RDD import spark.BlockRDD -import spark.Logging - -import java.io.InputStream +abstract class NetworkInputDStream[T: ClassManifest](@transient ssc: StreamingContext) + extends InputDStream[T](ssc) { -class NetworkInputDStream[T: ClassManifest]( - @transient ssc: StreamingContext, - val host: String, - val port: Int, - val bytesToObjects: InputStream => Iterator[T] - ) extends InputDStream[T](ssc) with Logging { - val id = ssc.getNewNetworkStreamId() - def start() { } + def start() {} - def stop() { } + def stop() {} override def compute(validTime: Time): Option[RDD[T]] = { val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) - return Some(new BlockRDD[T](ssc.sc, blockIds)) + Some(new BlockRDD[T](ssc.sc, blockIds)) } - - def createReceiver(): NetworkInputReceiver[T] = { - new NetworkInputReceiver(id, host, port, bytesToObjects) - } -} \ No newline at end of file + + /** Called on workers to run a receiver for the stream. */ + def runReceiver(): Unit +} + diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala b/streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala deleted file mode 100644 index 7add6246b7..0000000000 --- a/streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala +++ /dev/null @@ -1,248 +0,0 @@ -package spark.streaming - -import spark.Logging -import spark.storage.BlockManager -import spark.storage.StorageLevel -import spark.SparkEnv -import spark.streaming.util.SystemClock -import spark.streaming.util.RecurringTimer - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.Queue -import scala.collection.mutable.SynchronizedPriorityQueue -import scala.math.Ordering - -import java.net.InetSocketAddress -import java.net.Socket -import java.io.InputStream -import java.io.BufferedInputStream -import java.io.DataInputStream -import java.io.EOFException -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.ArrayBlockingQueue - -import akka.actor._ -import akka.pattern.ask -import akka.util.duration._ -import akka.dispatch._ - -trait NetworkInputReceiverMessage -case class GetBlockIds(time: Long) extends NetworkInputReceiverMessage -case class GotBlockIds(streamId: Int, blocksIds: Array[String]) extends NetworkInputReceiverMessage -case class StopReceiver() extends NetworkInputReceiverMessage - -class NetworkInputReceiver[T: ClassManifest](streamId: Int, host: String, port: Int, bytesToObjects: InputStream => Iterator[T]) -extends Logging { - - class ReceiverActor extends Actor { - override def preStart() = { - logInfo("Attempting to register") - val ip = System.getProperty("spark.master.host", "localhost") - val port = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "NetworkInputTracker" - val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) - val trackerActor = 
env.actorSystem.actorFor(url) - val timeout = 100.milliseconds - val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) - Await.result(future, timeout) - } - - def receive = { - case GetBlockIds(time) => { - logInfo("Got request for block ids for " + time) - sender ! GotBlockIds(streamId, dataHandler.getPushedBlocks()) - } - - case StopReceiver() => { - if (receivingThread != null) { - receivingThread.interrupt() - } - sender ! true - } - } - } - - class DataHandler { - - class Block(val time: Long, val iterator: Iterator[T]) { - val blockId = "input-" + streamId + "-" + time - var pushed = true - override def toString() = "input block " + blockId - } - - val clock = new SystemClock() - val blockInterval = 200L - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) - val blockOrdering = new Ordering[Block] { - def compare(b1: Block, b2: Block) = (b1.time - b2.time).toInt - } - val blockStorageLevel = StorageLevel.DISK_AND_MEMORY - val blocksForPushing = new ArrayBlockingQueue[Block](1000) - val blocksForReporting = new SynchronizedPriorityQueue[Block]()(blockOrdering) - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - var currentBuffer = new ArrayBuffer[T] - - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Data handler started") - } - - def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - } - - def += (obj: T) { - currentBuffer += obj - } - - def updateCurrentBuffer(time: Long) { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[T] - if (newBlockBuffer.size > 0) { - val newBlock = new Block(time - blockInterval, newBlockBuffer.toIterator) - blocksForPushing.add(newBlock) - blocksForReporting.enqueue(newBlock) - } - } - - def keepPushingBlocks() { - logInfo("Block pushing thread started") - try { - while(true) { - val block = blocksForPushing.take() - if (blockManager != null) { - blockManager.put(block.blockId, block.iterator, blockStorageLevel) - block.pushed = true - } else { - logWarning(block + " not put as block manager is null") - } - } - } catch { - case ie: InterruptedException => println("Block pushing thread interrupted") - case e: Exception => e.printStackTrace() - } - } - - def getPushedBlocks(): Array[String] = { - val pushedBlocks = new ArrayBuffer[String]() - var loop = true - while(loop && !blocksForReporting.isEmpty) { - val block = blocksForReporting.dequeue() - if (block == null) { - loop = false - } else if (!block.pushed) { - blocksForReporting.enqueue(block) - } else { - pushedBlocks += block.blockId - } - } - logInfo("Got " + pushedBlocks.size + " blocks") - pushedBlocks.toArray - } - } - - val blockManager = if (SparkEnv.get != null) SparkEnv.get.blockManager else null - val dataHandler = new DataHandler() - val env = SparkEnv.get - - var receiverActor: ActorRef = null - var receivingThread: Thread = null - - def run() { - initLogging() - var socket: Socket = null - try { - if (SparkEnv.get != null) { - receiverActor = SparkEnv.get.actorSystem.actorOf(Props(new ReceiverActor), "ReceiverActor-" + streamId) - } - dataHandler.start() - socket = connect() - receivingThread = Thread.currentThread() - receive(socket) - } catch { - case ie: InterruptedException => logInfo("Receiver interrupted") - } finally { - receivingThread = null - if (socket != null) socket.close() - dataHandler.stop() - } - } - - def connect(): Socket = { - logInfo("Connecting to " + host + ":" + port) - val socket = new 
Socket(host, port) - logInfo("Connected to " + host + ":" + port) - socket - } - - def receive(socket: Socket) { - val iterator = bytesToObjects(socket.getInputStream()) - while(iterator.hasNext) { - val obj = iterator.next - dataHandler += obj - } - } -} - - -object NetworkInputReceiver { - - def bytesToLines(inputStream: InputStream): Iterator[String] = { - val bufferedInputStream = new BufferedInputStream(inputStream) - val dataInputStream = new DataInputStream(bufferedInputStream) - - val iterator = new Iterator[String] { - var gotNext = false - var finished = false - var nextValue: String = null - - private def getNext() { - try { - nextValue = dataInputStream.readLine() - println("[" + nextValue + "]") - } catch { - case eof: EOFException => - finished = true - } - gotNext = true - } - - override def hasNext: Boolean = { - if (!gotNext) { - getNext() - } - if (finished) { - dataInputStream.close() - } - !finished - } - - - override def next(): String = { - if (!gotNext) { - getNext() - } - if (finished) { - throw new NoSuchElementException("End of stream") - } - gotNext = false - nextValue - } - } - iterator - } - - def main(args: Array[String]) { - if (args.length < 2) { - println("NetworkReceiver ") - System.exit(1) - } - val host = args(0) - val port = args(1).toInt - val receiver = new NetworkInputReceiver(0, host, port, bytesToLines) - receiver.run() - } -} diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala b/streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala new file mode 100644 index 0000000000..deaffe98c8 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala @@ -0,0 +1,7 @@ +package spark.streaming + +sealed trait NetworkInputReceiverMessage + +case class GetBlockIds(time: Long) extends NetworkInputReceiverMessage +case class GotBlockIds(streamId: Int, blocksIds: Array[String]) extends NetworkInputReceiverMessage +case object StopReceiver extends NetworkInputReceiverMessage diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index 07758665c9..acf97c1883 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -52,9 +52,7 @@ extends Logging { if (!iterator.hasNext) { throw new Exception("Could not start receiver as details not found.") } - val stream = iterator.next - val receiver = stream.createReceiver() - receiver.run() + iterator.next().runReceiver() } ssc.sc.runJob(tempRDD, startReceiver) @@ -62,8 +60,7 @@ extends Logging { def stopReceivers() { implicit val ec = env.actorSystem.dispatcher - val message = new StopReceiver() - val listOfFutures = receiverInfo.values.map(_.ask(message)(timeout)).toList + val listOfFutures = receiverInfo.values.map(_.ask(StopReceiver)(timeout)).toList val futureOfList = Future.sequence(listOfFutures) Await.result(futureOfList, timeout) } diff --git a/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala b/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala new file mode 100644 index 0000000000..2396b374a0 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala @@ -0,0 +1,16 @@ +package spark.streaming + +import java.io.InputStream + +class ObjectInputDStream[T: ClassManifest]( + @transient ssc: StreamingContext, + val host: String, + val port: Int, + val bytesToObjects: InputStream => 
Iterator[T]) + extends NetworkInputDStream[T](ssc) { + + override def runReceiver() { + new ObjectInputReceiver(id, host, port, bytesToObjects).run() + } +} + diff --git a/streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala b/streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala new file mode 100644 index 0000000000..70fa2cdf07 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala @@ -0,0 +1,244 @@ +package spark.streaming + +import spark.Logging +import spark.storage.BlockManager +import spark.storage.StorageLevel +import spark.SparkEnv +import spark.streaming.util.SystemClock +import spark.streaming.util.RecurringTimer + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.Queue +import scala.collection.mutable.SynchronizedPriorityQueue +import scala.math.Ordering + +import java.net.InetSocketAddress +import java.net.Socket +import java.io.InputStream +import java.io.BufferedInputStream +import java.io.DataInputStream +import java.io.EOFException +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.ArrayBlockingQueue + +import akka.actor._ +import akka.pattern.ask +import akka.util.duration._ +import akka.dispatch._ + +class ObjectInputReceiver[T: ClassManifest]( + streamId: Int, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T]) + extends Logging { + + class ReceiverActor extends Actor { + override def preStart() { + logInfo("Attempting to register") + val ip = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val actorName: String = "NetworkInputTracker" + val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) + val trackerActor = env.actorSystem.actorFor(url) + val timeout = 1.seconds + val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) + Await.result(future, timeout) + } + + def receive = { + case GetBlockIds(time) => { + logInfo("Got request for block ids for " + time) + sender ! GotBlockIds(streamId, dataHandler.getPushedBlocks()) + } + + case StopReceiver => { + if (receivingThread != null) { + receivingThread.interrupt() + } + sender ! 
true + } + } + } + + class DataHandler { + class Block(val time: Long, val iterator: Iterator[T]) { + val blockId = "input-" + streamId + "-" + time + var pushed = true + override def toString = "input block " + blockId + } + + val clock = new SystemClock() + val blockInterval = 200L + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockOrdering = new Ordering[Block] { + def compare(b1: Block, b2: Block) = (b1.time - b2.time).toInt + } + val blockStorageLevel = StorageLevel.DISK_AND_MEMORY + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blocksForReporting = new SynchronizedPriorityQueue[Block]()(blockOrdering) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + } + + def += (obj: T) { + currentBuffer += obj + } + + def updateCurrentBuffer(time: Long) { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val newBlock = new Block(time - blockInterval, newBlockBuffer.toIterator) + blocksForPushing.add(newBlock) + blocksForReporting.enqueue(newBlock) + } + } + + def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + if (blockManager != null) { + blockManager.put(block.blockId, block.iterator, blockStorageLevel) + block.pushed = true + } else { + logWarning(block + " not put as block manager is null") + } + } + } catch { + case ie: InterruptedException => println("Block pushing thread interrupted") + case e: Exception => e.printStackTrace() + } + } + + def getPushedBlocks(): Array[String] = { + val pushedBlocks = new ArrayBuffer[String]() + var loop = true + while(loop && !blocksForReporting.isEmpty) { + val block = blocksForReporting.dequeue() + if (block == null) { + loop = false + } else if (!block.pushed) { + blocksForReporting.enqueue(block) + } else { + pushedBlocks += block.blockId + } + } + logInfo("Got " + pushedBlocks.size + " blocks") + pushedBlocks.toArray + } + } + + val blockManager = if (SparkEnv.get != null) SparkEnv.get.blockManager else null + val dataHandler = new DataHandler() + val env = SparkEnv.get + + var receiverActor: ActorRef = null + var receivingThread: Thread = null + + def run() { + initLogging() + var socket: Socket = null + try { + if (SparkEnv.get != null) { + receiverActor = SparkEnv.get.actorSystem.actorOf(Props(new ReceiverActor), "ReceiverActor-" + streamId) + } + dataHandler.start() + socket = connect() + receivingThread = Thread.currentThread() + receive(socket) + } catch { + case ie: InterruptedException => logInfo("Receiver interrupted") + } finally { + receivingThread = null + if (socket != null) socket.close() + dataHandler.stop() + } + } + + def connect(): Socket = { + logInfo("Connecting to " + host + ":" + port) + val socket = new Socket(host, port) + logInfo("Connected to " + host + ":" + port) + socket + } + + def receive(socket: Socket) { + val iterator = bytesToObjects(socket.getInputStream()) + while(iterator.hasNext) { + val obj = iterator.next + dataHandler += obj + } + } +} + + +object ObjectInputReceiver { + def bytesToLines(inputStream: InputStream): Iterator[String] = { + val bufferedInputStream = new BufferedInputStream(inputStream) + val dataInputStream = new 
DataInputStream(bufferedInputStream) + + val iterator = new Iterator[String] { + var gotNext = false + var finished = false + var nextValue: String = null + + private def getNext() { + try { + nextValue = dataInputStream.readLine() + println("[" + nextValue + "]") + } catch { + case eof: EOFException => + finished = true + } + gotNext = true + } + + override def hasNext: Boolean = { + if (!gotNext) { + getNext() + } + if (finished) { + dataInputStream.close() + } + !finished + } + + override def next(): String = { + if (!gotNext) { + getNext() + } + if (finished) { + throw new NoSuchElementException("End of stream") + } + gotNext = false + nextValue + } + } + iterator + } + + def main(args: Array[String]) { + if (args.length < 2) { + println("ObjectInputReceiver ") + System.exit(1) + } + val host = args(0) + val port = args(1).toInt + val receiver = new ObjectInputReceiver(0, host, port, bytesToLines) + receiver.run() + } +} diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala new file mode 100644 index 0000000000..49e4781e75 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -0,0 +1,114 @@ +package spark.streaming + +import akka.actor._ +import akka.pattern.ask +import akka.util.duration._ +import akka.dispatch._ +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.{ReadableByteChannel, SocketChannel} +import java.io.EOFException +import java.util.concurrent.ArrayBlockingQueue +import scala.collection.mutable.ArrayBuffer +import spark.{DaemonThread, Logging, SparkEnv} +import spark.storage.StorageLevel + +/** + * An input stream that reads blocks of serialized objects from a given network address. + * The blocks will be inserted directly into the block store. This is the fastest way to get + * data into Spark Streaming, though it requires the sender to batch data and serialize it + * in the format that the system is configured with. + */ +class RawInputDStream[T: ClassManifest]( + @transient ssc: StreamingContext, + host: String, + port: Int) + extends NetworkInputDStream[T](ssc) with Logging { + + val streamId = id + + /** Called on workers to run a receiver for the stream. */ + def runReceiver() { + val env = SparkEnv.get + val actor = env.actorSystem.actorOf( + Props(new ReceiverActor(env, Thread.currentThread)), "ReceiverActor-" + streamId) + + // Open a socket to the target address and keep reading from it + logInfo("Connecting to " + host + ":" + port) + val channel = SocketChannel.open() + channel.configureBlocking(true) + channel.connect(new InetSocketAddress(host, port)) + logInfo("Connected to " + host + ":" + port) + + val queue = new ArrayBlockingQueue[ByteBuffer](2) + + new DaemonThread { + override def run() { + var nextBlockNumber = 0 + while (true) { + val buffer = queue.take() + val blockId = "input-" + streamId + "-" + nextBlockNumber + nextBlockNumber += 1 + env.blockManager.putBytes(blockId, buffer, StorageLevel.MEMORY_ONLY_2) + actor ! 
BlockPublished(blockId) + } + } + }.start() + + val lengthBuffer = ByteBuffer.allocate(4) + while (true) { + lengthBuffer.clear() + readFully(channel, lengthBuffer) + lengthBuffer.flip() + val length = lengthBuffer.getInt() + val dataBuffer = ByteBuffer.allocate(length) + readFully(channel, dataBuffer) + dataBuffer.flip() + logInfo("Read a block with " + length + " bytes") + queue.put(dataBuffer) + } + } + + /** Read a buffer fully from a given Channel */ + private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) { + while (dest.position < dest.limit) { + if (channel.read(dest) == -1) { + throw new EOFException("End of channel") + } + } + } + + /** Message sent to ReceiverActor to tell it that a block was published */ + case class BlockPublished(blockId: String) {} + + /** A helper actor that communicates with the NetworkInputTracker */ + private class ReceiverActor(env: SparkEnv, receivingThread: Thread) extends Actor { + val newBlocks = new ArrayBuffer[String] + + override def preStart() { + logInfo("Attempting to register with tracker") + val ip = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val actorName: String = "NetworkInputTracker" + val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) + val trackerActor = env.actorSystem.actorFor(url) + val timeout = 1.seconds + val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) + Await.result(future, timeout) + } + + override def receive = { + case BlockPublished(blockId) => + newBlocks += blockId + + case GetBlockIds(time) => + logInfo("Got request for block IDs for " + time) + sender ! GotBlockIds(streamId, newBlocks.toArray) + newBlocks.clear() + + case StopReceiver => + receivingThread.interrupt() + sender ! 
true + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 0ac86cbdf2..feb769e036 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -52,22 +52,22 @@ class StreamingContext ( private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() def createNetworkTextStream(hostname: String, port: Int): DStream[String] = { - createNetworkStream[String](hostname, port, NetworkInputReceiver.bytesToLines) + createNetworkObjectStream[String](hostname, port, ObjectInputReceiver.bytesToLines) } - def createNetworkStream[T: ClassManifest]( + def createNetworkObjectStream[T: ClassManifest]( hostname: String, port: Int, converter: (InputStream) => Iterator[T] ): DStream[T] = { - val inputStream = new NetworkInputDStream[T](this, hostname, port, converter) + val inputStream = new ObjectInputDStream[T](this, hostname, port, converter) inputStreams += inputStream inputStream } /* def createHttpTextStream(url: String): DStream[String] = { - createHttpStream(url, NetworkInputReceiver.bytesToLines) + createHttpStream(url, ObjectInputReceiver.bytesToLines) } def createHttpStream[T: ClassManifest]( diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala new file mode 100644 index 0000000000..60d5849d71 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala @@ -0,0 +1,51 @@ +package spark.streaming.util + +import spark.util.{RateLimitedOutputStream, IntParam} +import java.net.ServerSocket +import spark.{Logging, KryoSerializer} +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream +import io.Source +import java.io.IOException + +/** + * A helper program that sends blocks of Kryo-serialized text strings out on a socket at a + * specified rate. Used to feed data into RawInputDStream. 
+ */ +object RawTextSender extends Logging { + def main(args: Array[String]) { + if (args.length != 4) { + System.err.println("Usage: RawTextSender ") + } + // Parse the arguments using a pattern match + val Array(IntParam(port), file, IntParam(blockSize), IntParam(bytesPerSec)) = args + + // Repeat the input data multiple times to fill in a buffer + val lines = Source.fromFile(file).getLines().toArray + val bufferStream = new FastByteArrayOutputStream(blockSize + 1000) + val ser = new KryoSerializer().newInstance() + val serStream = ser.serializeStream(bufferStream) + var i = 0 + while (bufferStream.position < blockSize) { + serStream.writeObject(lines(i)) + i = (i + 1) % lines.length + } + bufferStream.trim() + val array = bufferStream.array + + val serverSocket = new ServerSocket(port) + + while (true) { + val socket = serverSocket.accept() + val out = new RateLimitedOutputStream(socket.getOutputStream, bytesPerSec) + try { + while (true) { + out.write(array) + } + } catch { + case e: IOException => + logError("Socket closed: ", e) + socket.close() + } + } + } +} -- cgit v1.2.3 From f84d2bbe55aaf3ef7a6631b9018a573aa5729ff7 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 1 Sep 2012 00:31:15 -0700 Subject: Bug fixes to RateLimitedOutputStream --- core/src/main/scala/spark/util/RateLimitedOutputStream.scala | 2 +- streaming/src/main/scala/spark/streaming/util/RawTextSender.scala | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index 10f2272707..d11ed163ce 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -21,7 +21,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu while (pos < length) { val writeSize = math.min(length - pos, CHUNK_SIZE) waitToWrite(writeSize) - out.write(bytes, offset + pos, length - pos) + out.write(bytes, offset + pos, writeSize) pos += writeSize } } diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala index 60d5849d71..85927c02ec 100644 --- a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala @@ -15,6 +15,7 @@ object RawTextSender extends Logging { def main(args: Array[String]) { if (args.length != 4) { System.err.println("Usage: RawTextSender ") + System.exit(1) } // Parse the arguments using a pattern match val Array(IntParam(port), file, IntParam(blockSize), IntParam(bytesPerSec)) = args @@ -36,6 +37,7 @@ object RawTextSender extends Logging { while (true) { val socket = serverSocket.accept() + logInfo("Got a new connection") val out = new RateLimitedOutputStream(socket.getOutputStream, bytesPerSec) try { while (true) { @@ -43,7 +45,7 @@ object RawTextSender extends Logging { } } catch { case e: IOException => - logError("Socket closed: ", e) + logError("Socket closed", e) socket.close() } } -- cgit v1.2.3 From 6025889be0ecf1c9849c5c940a7171c6d82be0b5 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 1 Sep 2012 20:51:07 +0000 Subject: More raw network receiver programs --- .../mesos/CoarseMesosSchedulerBackend.scala | 4 ++- .../scala/spark/streaming/examples/GrepRaw.scala | 33 +++++++++++++++++ .../spark/streaming/examples/WordCount2.scala | 2 +- .../spark/streaming/examples/WordCountRaw.scala | 
42 ++++++++++++++++++++++ 4 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index 31784985dc..fdf007ffb2 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -80,6 +80,8 @@ class CoarseMesosSchedulerBackend( "property, the SPARK_HOME environment variable or the SparkContext constructor") } + val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt + var nextMesosTaskId = 0 def newMesosTaskId(): Int = { @@ -177,7 +179,7 @@ class CoarseMesosSchedulerBackend( val task = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse)) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) .addResources(createResource("mem", executorMemory)) diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala new file mode 100644 index 0000000000..cc52da7bd4 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala @@ -0,0 +1,33 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel +import spark.streaming._ +import spark.streaming.StreamingContext._ + +object GrepRaw { + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println("Usage: GrepRaw ") + System.exit(1) + } + + val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "GrepRaw") + ssc.setBatchDuration(Milliseconds(batchMillis)) + + // Make sure some tasks have started on each node + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + + val rawStreams = (1 to numStreams).map(_ => + ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + val union = new UnifiedDStream(rawStreams) + union.filter(_.contains("Culpepper")).count().foreachRDD(r => + println("Grep count: " + r.collect().mkString)) + ssc.start() + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index ce553758a7..8c2724e97c 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -100,7 +100,7 @@ object WordCount2 { .reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt) windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Milliseconds(chkptMillis.toLong)) - windowedCounts.print() + windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala new file 
mode 100644 index 0000000000..298d9ef381 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -0,0 +1,42 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel +import spark.streaming._ +import spark.streaming.StreamingContext._ + +object WordCountRaw { + def main(args: Array[String]) { + if (args.length != 7) { + System.err.println("Usage: WordCountRaw ") + System.exit(1) + } + + val Array(master, IntParam(streams), host, IntParam(port), IntParam(batchMs), + IntParam(chkptMs), IntParam(reduces)) = args + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "WordCountRaw") + ssc.setBatchDuration(Milliseconds(batchMs)) + + // Make sure some tasks have started on each node + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + + val rawStreams = (1 to streams).map(_ => + ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + val union = new UnifiedDStream(rawStreams) + + import WordCount2_ExtraFunctions._ + + val windowedCounts = union.mapPartitions(splitAndCountPartitions) + .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, + Milliseconds(chkptMs)) + //windowedCounts.print() // TODO: something else? + windowedCounts.foreachRDD(r => println("Element count: " + r.count())) + + ssc.start() + } +} -- cgit v1.2.3 From ceabf71257631c9e46f82897e540369b99a6bb57 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 1 Sep 2012 21:52:42 +0000 Subject: tweaks --- core/src/main/scala/spark/storage/StorageLevel.scala | 1 + streaming/src/main/scala/spark/streaming/util/RawTextSender.scala | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index f067a2a6c5..a64393eba7 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -66,6 +66,7 @@ class StorageLevel( object StorageLevel { val NONE = new StorageLevel(false, false, false) val DISK_ONLY = new StorageLevel(true, false, false) + val DISK_ONLY_2 = new StorageLevel(true, false, false, 2) val MEMORY_ONLY = new StorageLevel(false, true, false) val MEMORY_ONLY_2 = new StorageLevel(false, true, false, 2) val MEMORY_ONLY_DESER = new StorageLevel(false, true, true) diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala index 8db651ba19..d8b987ec86 100644 --- a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala @@ -52,7 +52,7 @@ object RawTextSender extends Logging { } } catch { case e: IOException => - logError("Socket closed", e) + logError("Client disconnected") socket.close() } } -- cgit v1.2.3 From 7c09ad0e04639040864236cf13a9fedff6736b5d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 4 Sep 2012 19:11:49 -0700 Subject: Changed DStream member access permissions from private to protected. Updated StateDStream to checkpoint RDDs and forget lineage. 
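For reference, the lineage-forgetting step in the StateDStream change below amounts to replacing a persisted RDD with a BlockRDD over its cached blocks. A minimal sketch of that idea (helper name and imports are illustrative, not part of this patch; it assumes the RDD has already been persisted so each split is stored under the "rdd:<id>:<index>" block naming used by the block manager):

    import spark.{BlockRDD, RDD, SparkContext}

    // Sketch only: swap a persisted RDD for a BlockRDD over its cached blocks,
    // so the new RDD no longer carries its parent dependencies (lineage).
    def forgetLineage[T: ClassManifest](sc: SparkContext, rdd: RDD[T]): RDD[T] = {
      // Block IDs follow the assumed "rdd:<rddId>:<splitIndex>" convention.
      val blockIds = rdd.splits.map(s => "rdd:" + rdd.id + ":" + s.index)
      // A BlockRDD reads those blocks directly and has an empty dependency list.
      new BlockRDD[T](sc, blockIds)
    }

StateDStream.getOrCompute in the diff below performs the same swap at checkpoint intervals, additionally preserving the old RDD's partitioner.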
--- core/src/main/scala/spark/RDD.scala | 2 +- .../src/main/scala/spark/streaming/DStream.scala | 16 ++-- .../scala/spark/streaming/QueueInputDStream.scala | 2 +- .../main/scala/spark/streaming/StateDStream.scala | 93 ++++++++++++++++------ .../test/scala/spark/streaming/DStreamSuite.scala | 4 +- 5 files changed, 81 insertions(+), 36 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 3fe8e8a4bf..d28f3593fe 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -94,7 +94,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def getStorageLevel = storageLevel - def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER): RDD[T] = { + def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER_2): RDD[T] = { if (!level.useDisk && level.replication < 2) { throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 9b0115eef6..20f1c4db20 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -41,17 +41,17 @@ extends Logging with Serializable { */ // Variable to store the RDDs generated earlier in time - @transient private val generatedRDDs = new HashMap[Time, RDD[T]] () + @transient protected val generatedRDDs = new HashMap[Time, RDD[T]] () // Variable to be set to the first time seen by the DStream (effective time zero) - private[streaming] var zeroTime: Time = null + protected[streaming] var zeroTime: Time = null // Variable to specify storage level - private var storageLevel: StorageLevel = StorageLevel.NONE + protected var storageLevel: StorageLevel = StorageLevel.NONE // Checkpoint level and checkpoint interval - private var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint - private var checkpointInterval: Time = null + protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint + protected var checkpointInterval: Time = null // Change this RDD's storage level def persist( @@ -84,7 +84,7 @@ extends Logging with Serializable { * the validity of future times is calculated. This method also recursively initializes * its parent DStreams. */ - def initialize(time: Time) { + protected[streaming] def initialize(time: Time) { if (zeroTime == null) { zeroTime = time } @@ -93,7 +93,7 @@ extends Logging with Serializable { } /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ - private def isTimeValid (time: Time): Boolean = { + protected def isTimeValid (time: Time): Boolean = { if (!isInitialized) { throw new Exception (this.toString + " has not been initialized") } else if (time < zeroTime || ! 
(time - zeroTime).isMultipleOf(slideTime)) { @@ -208,7 +208,7 @@ extends Logging with Serializable { new TransformedDStream(this, ssc.sc.clean(transformFunc)) } - private[streaming] def toQueue = { + def toQueue = { val queue = new ArrayBlockingQueue[RDD[T]](10000) this.foreachRDD(rdd => { queue.add(rdd) diff --git a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala index bab48ff954..de30297c7d 100644 --- a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala @@ -7,7 +7,7 @@ import scala.collection.mutable.Queue import scala.collection.mutable.ArrayBuffer class QueueInputDStream[T: ClassManifest]( - ssc: StreamingContext, + @transient ssc: StreamingContext, val queue: Queue[RDD[T]], oneAtATime: Boolean, defaultRDD: RDD[T] diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index f313d8c162..4cb780c006 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -1,10 +1,11 @@ package spark.streaming import spark.RDD +import spark.BlockRDD import spark.Partitioner import spark.MapPartitionsRDD import spark.SparkContext._ - +import spark.storage.StorageLevel class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( parent: DStream[(K, V)], @@ -22,6 +23,47 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife override def slideTime = parent.slideTime + override def getOrCompute(time: Time): Option[RDD[(K, S)]] = { + generatedRDDs.get(time) match { + case Some(oldRDD) => { + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval) && oldRDD.dependencies.size > 0) { + val r = oldRDD + val oldRDDBlockIds = oldRDD.splits.map(s => "rdd:" + r.id + ":" + s.index) + val checkpointedRDD = new BlockRDD[(K, S)](ssc.sc, oldRDDBlockIds) { + override val partitioner = oldRDD.partitioner + } + generatedRDDs.update(time, checkpointedRDD) + logInfo("Updated RDD of time " + time + " with its checkpointed version") + Some(checkpointedRDD) + } else { + Some(oldRDD) + } + } + case None => { + if (isTimeValid(time)) { + compute(time) match { + case Some(newRDD) => { + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.persist(checkpointLevel) + logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) + } else if (storageLevel != StorageLevel.NONE) { + newRDD.persist(storageLevel) + logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) + } + generatedRDDs.put(time, newRDD) + Some(newRDD) + } + case None => { + None + } + } + } else { + None + } + } + } + } + override def compute(validTime: Time): Option[RDD[(K, S)]] = { // Try to get the previous state RDD @@ -29,26 +71,27 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife case Some(prevStateRDD) => { // If previous state RDD exists - // Define the function for the mapPartition operation on cogrouped RDD; - // first map the cogrouped tuple to tuples of required type, - // and then apply the update function - val func = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { - val i = iterator.map(t => { - (t._1, t._2._1, t._2._2.headOption.getOrElse(null.asInstanceOf[S])) - }) - updateFunc(i) - } - // Try to get the parent RDD 
parent.getOrCompute(validTime) match { case Some(parentRDD) => { // If parent RDD exists, then compute as usual + + // Define the function for the mapPartition operation on cogrouped RDD; + // first map the cogrouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val mapPartitionFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { + val i = iterator.map(t => { + (t._1, t._2._1, t._2._2.headOption.getOrElse(null.asInstanceOf[S])) + }) + updateFuncLocal(i) + } val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) - val stateRDD = new SpecialMapPartitionsRDD(cogroupedRDD, func) - logDebug("Generating state RDD for time " + validTime) + val stateRDD = new SpecialMapPartitionsRDD(cogroupedRDD, mapPartitionFunc) + //logDebug("Generating state RDD for time " + validTime) return Some(stateRDD) } case None => { // If parent RDD does not exist, then return old state RDD - logDebug("Generating state RDD for time " + validTime + " (no change)") + //logDebug("Generating state RDD for time " + validTime + " (no change)") return Some(prevStateRDD) } } @@ -56,23 +99,25 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife case None => { // If previous session RDD does not exist (first input data) - // Define the function for the mapPartition operation on grouped RDD; - // first map the grouped tuple to tuples of required type, - // and then apply the update function - val func = (iterator: Iterator[(K, Seq[V])]) => { - updateFunc(iterator.map(tuple => (tuple._1, tuple._2, null.asInstanceOf[S]))) - } - // Try to get the parent RDD parent.getOrCompute(validTime) match { case Some(parentRDD) => { // If parent RDD exists, then compute as usual + + // Define the function for the mapPartition operation on grouped RDD; + // first map the grouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val mapPartitionFunc = (iterator: Iterator[(K, Seq[V])]) => { + updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, null.asInstanceOf[S]))) + } + val groupedRDD = parentRDD.groupByKey(partitioner) - val sessionRDD = new SpecialMapPartitionsRDD(groupedRDD, func) - logDebug("Generating state RDD for time " + validTime + " (first)") + val sessionRDD = new SpecialMapPartitionsRDD(groupedRDD, mapPartitionFunc) + //logDebug("Generating state RDD for time " + validTime + " (first)") return Some(sessionRDD) } case None => { // If parent RDD does not exist, then nothing to do! 
- logDebug("Not generating state RDD (no previous state, no parent)") + //logDebug("Not generating state RDD (no previous state, no parent)") return None } } diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala index 030f351080..fc00952afe 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala @@ -107,12 +107,12 @@ class DStreamSuite extends FunSuite with BeforeAndAfter with Logging { Seq(("a", 3), ("b", 3), ("c", 3)) ) - val updateStateOp =(s: DStream[String]) => { + val updateStateOp = (s: DStream[String]) => { val updateFunc = (values: Seq[Int], state: RichInt) => { var newState = 0 if (values != null) newState += values.reduce(_ + _) if (state != null) newState += state.self - //println("values = " + values + ", state = " + state + ", " + " new state = " + newState) + println("values = " + values + ", state = " + state + ", " + " new state = " + newState) new RichInt(newState) } s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self)) -- cgit v1.2.3 From 4ea032a142ab7fb44f92b145cc8d850164419ab5 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 05:53:07 +0000 Subject: Some changes to make important log output visible even if we set the logging to WARNING --- .../main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala | 2 +- streaming/src/main/scala/spark/streaming/JobManager.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 83e7c6e036..978b4f2676 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -99,7 +99,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor // Remove a disconnected slave from the cluster def removeSlave(slaveId: String) { - logInfo("Slave " + slaveId + " disconnected, so removing it") + logWarning("Slave " + slaveId + " disconnected, so removing it") val numCores = freeCores(slaveId) actorToSlaveId -= slaveActor(slaveId) addressToSlaveId -= slaveAddress(slaveId) diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 9bf9251519..230d806a89 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -12,7 +12,7 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { SparkEnv.set(ssc.env) try { val timeTaken = job.run() - logInfo("Total delay: %.5f s for job %s (execution: %.5f s)".format( + println("Total delay: %.5f s for job %s (execution: %.5f s)".format( (System.currentTimeMillis() - job.time) / 1000.0, job.id, timeTaken / 1000.0)) } catch { case e: Exception => -- cgit v1.2.3 From b7ad291ac52896af6cb1d882392f3d6fa0cf3b49 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 07:08:07 +0000 Subject: Tuning Akka for more connections --- core/src/main/scala/spark/util/AkkaUtils.scala | 1 + .../src/main/scala/spark/streaming/examples/WordCount2.scala | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala 
b/core/src/main/scala/spark/util/AkkaUtils.scala index 57d212e4ca..fd64e224d7 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -31,6 +31,7 @@ object AkkaUtils { akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d akka.remote.netty.connection-timeout = 1s + akka.remote.netty.execution-pool-size = 10 """.format(host, port)) val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index aa542ba07d..8561e7f079 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -62,10 +62,10 @@ object WordCount2_ExtraFunctions { object WordCount2 { def warmup(sc: SparkContext) { - (0 until 10).foreach {i => - sc.parallelize(1 to 20000000, 1000) + (0 until 3).foreach {i => + sc.parallelize(1 to 20000000, 500) .map(x => (x % 337, x % 1331)) - .reduceByKey(_ + _) + .reduceByKey(_ + _, 100) .count() } } @@ -84,11 +84,11 @@ object WordCount2 { val ssc = new StreamingContext(master, "WordCount2") ssc.setBatchDuration(batchDuration) - //warmup(ssc.sc) + warmup(ssc.sc) val data = ssc.sc.textFile(file, mapTasks.toInt).persist( new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas - println("Data count: " + data.count()) + println("Data count: " + data.map(x => if (x == "") 1 else x.split(" ").size / x.split(" ").size).count()) println("Data count: " + data.count()) println("Data count: " + data.count()) -- cgit v1.2.3 From 75487b2f5a6abedd322520f759b814ec643aea01 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 08:14:50 +0000 Subject: Broadcast the JobConf in HadoopRDD to reduce task sizes --- core/src/main/scala/spark/HadoopRDD.scala | 5 +++-- core/src/main/scala/spark/KryoSerializer.scala | 4 ++++ 2 files changed, 7 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/HadoopRDD.scala b/core/src/main/scala/spark/HadoopRDD.scala index f282a4023b..0befca582d 100644 --- a/core/src/main/scala/spark/HadoopRDD.scala +++ b/core/src/main/scala/spark/HadoopRDD.scala @@ -42,7 +42,8 @@ class HadoopRDD[K, V]( minSplits: Int) extends RDD[(K, V)](sc) { - val serializableConf = new SerializableWritable(conf) + // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it + val confBroadcast = sc.broadcast(new SerializableWritable(conf)) @transient val splits_ : Array[Split] = { @@ -66,7 +67,7 @@ class HadoopRDD[K, V]( val split = theSplit.asInstanceOf[HadoopSplit] var reader: RecordReader[K, V] = null - val conf = serializableConf.value + val conf = confBroadcast.value.value val fmt = createInputFormat(conf) reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL) diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 65d0532bd5..3d042b2f11 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -10,6 +10,7 @@ import scala.collection.mutable import com.esotericsoftware.kryo._ import com.esotericsoftware.kryo.{Serializer => KSerializer} import com.esotericsoftware.kryo.serialize.ClassSerializer +import com.esotericsoftware.kryo.serialize.SerializableSerializer import de.javakaffee.kryoserializers.KryoReflectionFactorySupport import spark.storage._ @@ -203,6 +204,9 @@ class 
KryoSerializer extends Serializer with Logging { kryo.register(classOf[Class[_]], new ClassSerializer(kryo)) kryo.setRegistrationOptional(true) + // Allow sending SerializableWritable + kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer()) + // Register some commonly used Scala singleton objects. Because these // are singletons, we must return the exact same local object when we // deserialize rather than returning a clone as FieldSerializer would. -- cgit v1.2.3 From efc7668d16b2a58f8d074c1cdaeae4b37dae1c9c Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 08:22:57 +0000 Subject: Allow serializing HttpBroadcast through Kryo --- core/src/main/scala/spark/KryoSerializer.scala | 2 ++ 1 file changed, 2 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 3d042b2f11..8a3f565071 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -13,6 +13,7 @@ import com.esotericsoftware.kryo.serialize.ClassSerializer import com.esotericsoftware.kryo.serialize.SerializableSerializer import de.javakaffee.kryoserializers.KryoReflectionFactorySupport +import spark.broadcast._ import spark.storage._ /** @@ -206,6 +207,7 @@ class KryoSerializer extends Serializer with Logging { // Allow sending SerializableWritable kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer()) + kryo.register(classOf[HttpBroadcast[_]], new SerializableSerializer()) // Register some commonly used Scala singleton objects. Because these // are singletons, we must return the exact same local object when we -- cgit v1.2.3 From 3fa0d7f0c9883ab77e89b7bcf70b7b11df9a4184 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 08:28:15 +0000 Subject: Serialize BlockRDD more efficiently --- core/src/main/scala/spark/BlockRDD.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/BlockRDD.scala b/core/src/main/scala/spark/BlockRDD.scala index ea009f0f4f..daabc0d566 100644 --- a/core/src/main/scala/spark/BlockRDD.scala +++ b/core/src/main/scala/spark/BlockRDD.scala @@ -7,7 +7,8 @@ class BlockRDDSplit(val blockId: String, idx: Int) extends Split { } -class BlockRDD[T: ClassManifest](sc: SparkContext, blockIds: Array[String]) extends RDD[T](sc) { +class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) + extends RDD[T](sc) { @transient val splits_ = (0 until blockIds.size).map(i => { -- cgit v1.2.3 From 1d6b36d3c3698090b35d8e7c4f88cac410f9ea01 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 22:26:37 +0000 Subject: Further tuning for network performance --- core/src/main/scala/spark/storage/BlockMessage.scala | 14 +------------- core/src/main/scala/spark/util/AkkaUtils.scala | 3 ++- 2 files changed, 3 insertions(+), 14 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala index 0b2ed69e07..607633c6df 100644 --- a/core/src/main/scala/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/spark/storage/BlockMessage.scala @@ -12,7 +12,7 @@ case class GetBlock(id: String) case class GotBlock(id: String, data: ByteBuffer) case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel) -class BlockMessage() extends Logging{ +class BlockMessage() { // Un-initialized: typ = 0 // GetBlock: typ = 1 // GotBlock: typ = 2 @@ -22,8 +22,6 @@ 
class BlockMessage() extends Logging{ private var data: ByteBuffer = null private var level: StorageLevel = null - initLogging() - def set(getBlock: GetBlock) { typ = BlockMessage.TYPE_GET_BLOCK id = getBlock.id @@ -62,8 +60,6 @@ class BlockMessage() extends Logging{ } id = idBuilder.toString() - logDebug("Set from buffer Result: " + typ + " " + id) - logDebug("Buffer position is " + buffer.position) if (typ == BlockMessage.TYPE_PUT_BLOCK) { val booleanInt = buffer.getInt() @@ -77,23 +73,18 @@ class BlockMessage() extends Logging{ } data.put(buffer) data.flip() - logDebug("Set from buffer Result 2: " + level + " " + data) } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { val dataLength = buffer.getInt() - logDebug("Data length is "+ dataLength) - logDebug("Buffer position is " + buffer.position) data = ByteBuffer.allocate(dataLength) if (dataLength != buffer.remaining) { throw new Exception("Error parsing buffer") } data.put(buffer) data.flip() - logDebug("Set from buffer Result 3: " + data) } val finishTime = System.currentTimeMillis - logDebug("Converted " + id + " from bytebuffer in " + (finishTime - startTime) / 1000.0 + " s") } def set(bufferMsg: BufferMessage) { @@ -145,8 +136,6 @@ class BlockMessage() extends Logging{ buffers += data } - logDebug("Start to log buffers.") - buffers.foreach((x: ByteBuffer) => logDebug("" + x)) /* println() println("BlockMessage: ") @@ -160,7 +149,6 @@ class BlockMessage() extends Logging{ println() */ val finishTime = System.currentTimeMillis - logDebug("Converted " + id + " to buffer message in " + (finishTime - startTime) / 1000.0 + " s") return Message.createBufferMessage(buffers) } diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index fd64e224d7..330bb42e59 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -31,7 +31,8 @@ object AkkaUtils { akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d akka.remote.netty.connection-timeout = 1s - akka.remote.netty.execution-pool-size = 10 + akka.remote.netty.execution-pool-size = 4 + akka.actor.default-dispatcher.throughput = 20 """.format(host, port)) val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) -- cgit v1.2.3 From dc68febdce53efea43ee1ab91c05b14b7a5eae30 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 23:06:59 +0000 Subject: Use Spark's closure serializer for the ShuffleMapTask cache --- core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 9 ++++----- .../main/scala/spark/scheduler/cluster/ClusterScheduler.scala | 1 + 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 73479bff01..f1eae9bc88 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -26,7 +26,8 @@ object ShuffleMapTask { return old } else { val out = new ByteArrayOutputStream - val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) + val ser = SparkEnv.get.closureSerializer.newInstance + val objOut = ser.serializeStream(new GZIPOutputStream(out)) objOut.writeObject(rdd) objOut.writeObject(dep) objOut.close() @@ -45,10 +46,8 @@ object ShuffleMapTask { } else { val loader = Thread.currentThread.getContextClassLoader val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val objIn = new ObjectInputStream(in) { - override def 
resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, loader) - } + val ser = SparkEnv.get.closureSerializer.newInstance + val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]] val tuple = (rdd, dep) diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 5b59479682..20c82ad0fa 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -115,6 +115,7 @@ class ClusterScheduler(sc: SparkContext) */ def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = { synchronized { + SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname for (o <- offers) { slaveIdToHost(o.slaveId) = o.hostname -- cgit v1.2.3 From 215544820fe70274d9dce1410f61e2052b8bc406 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 23:54:04 +0000 Subject: Serialize map output locations more efficiently, and only once, in MapOutputTracker --- core/src/main/scala/spark/MapOutputTracker.scala | 88 +++++++++++++++++++++--- 1 file changed, 80 insertions(+), 8 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index de23eb6f48..cee2391c71 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -1,5 +1,6 @@ package spark +import java.io.{DataInputStream, DataOutputStream, ByteArrayOutputStream, ByteArrayInputStream} import java.util.concurrent.ConcurrentHashMap import akka.actor._ @@ -10,6 +11,7 @@ import akka.util.Duration import akka.util.Timeout import akka.util.duration._ +import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet import spark.storage.BlockManagerId @@ -18,12 +20,11 @@ sealed trait MapOutputTrackerMessage case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage case object StopMapOutputTracker extends MapOutputTrackerMessage -class MapOutputTrackerActor(bmAddresses: ConcurrentHashMap[Int, Array[BlockManagerId]]) -extends Actor with Logging { +class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging { def receive = { case GetMapOutputLocations(shuffleId: Int) => logInfo("Asked to get map output locations for shuffle " + shuffleId) - sender ! bmAddresses.get(shuffleId) + sender ! tracker.getSerializedLocations(shuffleId) case StopMapOutputTracker => logInfo("MapOutputTrackerActor stopped!") @@ -39,15 +40,19 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg val timeout = 10.seconds - private var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]] + var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. 
private var generation: Long = 0 private var generationLock = new java.lang.Object + // Cache a serialized version of the output locations for each shuffle to send them out faster + var cacheGeneration = generation + val cachedSerializedLocs = new HashMap[Int, Array[Byte]] + var trackerActor: ActorRef = if (isMaster) { - val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(bmAddresses)), name = actorName) + val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName) logInfo("Registered MapOutputTrackerActor actor") actor } else { @@ -134,15 +139,16 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg } // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) - val fetched = askTracker(GetMapOutputLocations(shuffleId)).asInstanceOf[Array[BlockManagerId]] + val fetchedBytes = askTracker(GetMapOutputLocations(shuffleId)).asInstanceOf[Array[Byte]] + val fetchedLocs = deserializeLocations(fetchedBytes) logInfo("Got the output locations") - bmAddresses.put(shuffleId, fetched) + bmAddresses.put(shuffleId, fetchedLocs) fetching.synchronized { fetching -= shuffleId fetching.notifyAll() } - return fetched + return fetchedLocs } else { return locs } @@ -181,4 +187,70 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg } } } + + def getSerializedLocations(shuffleId: Int): Array[Byte] = { + var locs: Array[BlockManagerId] = null + var generationGotten: Long = -1 + generationLock.synchronized { + if (generation > cacheGeneration) { + cachedSerializedLocs.clear() + cacheGeneration = generation + } + cachedSerializedLocs.get(shuffleId) match { + case Some(bytes) => + return bytes + case None => + locs = bmAddresses.get(shuffleId) + generationGotten = generation + } + } + // If we got here, we failed to find the serialized locations in the cache, so we pulled + // out a snapshot of the locations as "locs"; let's serialize and return that + val bytes = serializeLocations(locs) + // Add them into the table only if the generation hasn't changed while we were working + generationLock.synchronized { + if (generation == generationGotten) { + cachedSerializedLocs(shuffleId) = bytes + } + } + return bytes + } + + // Serialize an array of map output locations into an efficient byte format so that we can send + // it to reduce tasks. We do this by grouping together the locations by block manager ID. + def serializeLocations(locs: Array[BlockManagerId]): Array[Byte] = { + val out = new ByteArrayOutputStream + val dataOut = new DataOutputStream(out) + dataOut.writeInt(locs.length) + val grouped = locs.zipWithIndex.groupBy(_._1) + dataOut.writeInt(grouped.size) + for ((id, pairs) <- grouped) { + dataOut.writeUTF(id.ip) + dataOut.writeInt(id.port) + dataOut.writeInt(pairs.length) + for ((_, blockIndex) <- pairs) { + dataOut.writeInt(blockIndex) + } + } + dataOut.close() + out.toByteArray + } + + // Opposite of serializeLocations. 
+ def deserializeLocations(bytes: Array[Byte]): Array[BlockManagerId] = { + val dataIn = new DataInputStream(new ByteArrayInputStream(bytes)) + val length = dataIn.readInt() + val array = new Array[BlockManagerId](length) + val numGroups = dataIn.readInt() + for (i <- 0 until numGroups) { + val ip = dataIn.readUTF() + val port = dataIn.readInt() + val id = new BlockManagerId(ip, port) + val numBlocks = dataIn.readInt() + for (j <- 0 until numBlocks) { + array(dataIn.readInt()) = id + } + } + array + } } -- cgit v1.2.3 From 2fa6d999fd92cb7ce828278edcd09eecd1f458c1 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 6 Sep 2012 00:16:39 +0000 Subject: Tuning Akka more --- core/src/main/scala/spark/util/AkkaUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index 330bb42e59..df4e23bfd6 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -31,8 +31,8 @@ object AkkaUtils { akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d akka.remote.netty.connection-timeout = 1s - akka.remote.netty.execution-pool-size = 4 - akka.actor.default-dispatcher.throughput = 20 + akka.remote.netty.execution-pool-size = 8 + akka.actor.default-dispatcher.throughput = 30 """.format(host, port)) val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) -- cgit v1.2.3 From 9ef90c95f4947e47f7c44f952ff8d294e0932a73 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 6 Sep 2012 00:43:46 +0000 Subject: Bug fix --- core/src/main/scala/spark/MapOutputTracker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index cee2391c71..82c1391345 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -224,7 +224,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg dataOut.writeInt(locs.length) val grouped = locs.zipWithIndex.groupBy(_._1) dataOut.writeInt(grouped.size) - for ((id, pairs) <- grouped) { + for ((id, pairs) <- grouped if id != null) { dataOut.writeUTF(id.ip) dataOut.writeInt(id.port) dataOut.writeInt(pairs.length) -- cgit v1.2.3 From db08a362aae68682f9105f9e5568bc9b9d9faaab Mon Sep 17 00:00:00 2001 From: haoyuan Date: Fri, 7 Sep 2012 02:17:52 +0000 Subject: commit opt for grep scalability test. --- .../main/scala/spark/storage/BlockManager.scala | 7 +++- .../spark/streaming/NetworkInputTracker.scala | 40 ++++++++++---------- .../scala/spark/streaming/RawInputDStream.scala | 17 +++++---- .../spark/streaming/examples/WordCountRaw.scala | 19 +++++--- 4 files changed, 51 insertions(+), 32 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index f2d9499bad..4cdb9710ec 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -509,10 +509,15 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * Replicate block to another node. 
*/ + var firstTime = true + var peers : Seq[BlockManagerId] = null private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { val tLevel: StorageLevel = new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) - var peers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + if (firstTime) { + peers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + firstTime = false; + } for (peer: BlockManagerId <- peers) { val start = System.nanoTime data.rewind() diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index acf97c1883..9f9001e4d5 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -4,6 +4,7 @@ import spark.Logging import spark.SparkEnv import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue import akka.actor._ import akka.pattern.ask @@ -28,6 +29,17 @@ extends Logging { logInfo("Registered receiver for network stream " + streamId) sender ! true } + case GotBlockIds(streamId, blockIds) => { + val tmp = receivedBlockIds.synchronized { + if (!receivedBlockIds.contains(streamId)) { + receivedBlockIds += ((streamId, new Queue[String])) + } + receivedBlockIds(streamId) + } + tmp.synchronized { + tmp ++= blockIds + } + } } } @@ -69,8 +81,8 @@ extends Logging { val networkInputStreamIds = networkInputStreams.map(_.id).toArray val receiverExecutor = new ReceiverExecutor() val receiverInfo = new HashMap[Int, ActorRef] - val receivedBlockIds = new HashMap[Int, Array[String]] - val timeout = 1000.milliseconds + val receivedBlockIds = new HashMap[Int, Queue[String]] + val timeout = 5000.milliseconds var currentTime: Time = null @@ -86,22 +98,12 @@ extends Logging { } def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized { - if (currentTime == null || time > currentTime) { - logInfo("Getting block ids from receivers for " + time) - implicit val ec = ssc.env.actorSystem.dispatcher - receivedBlockIds.clear() - val message = new GetBlockIds(time) - val listOfFutures = receiverInfo.values.map( - _.ask(message)(timeout).mapTo[GotBlockIds] - ).toList - val futureOfList = Future.sequence(listOfFutures) - val allBlockIds = Await.result(futureOfList, timeout) - receivedBlockIds ++= allBlockIds.map(x => (x.streamId, x.blocksIds)) - if (receivedBlockIds.size != receiverInfo.size) { - throw new Exception("Unexpected number of the Block IDs received") - } - currentTime = time + val queue = receivedBlockIds.synchronized { + receivedBlockIds.getOrElse(receiverId, new Queue[String]()) + } + val result = queue.synchronized { + queue.dequeueAll(x => true) } - receivedBlockIds.getOrElse(receiverId, Array[String]()) + result.toArray } -} \ No newline at end of file +} diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala index d59c245a23..d29aea7886 100644 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -86,14 +86,15 @@ class RawInputDStream[T: ClassManifest]( private class ReceiverActor(env: SparkEnv, receivingThread: Thread) extends Actor { val newBlocks = new ArrayBuffer[String] + logInfo("Attempting to register with tracker") + val ip = System.getProperty("spark.master.host", "localhost") + val port = 
System.getProperty("spark.master.port", "7077").toInt + val actorName: String = "NetworkInputTracker" + val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) + val trackerActor = env.actorSystem.actorFor(url) + val timeout = 5.seconds + override def preStart() { - logInfo("Attempting to register with tracker") - val ip = System.getProperty("spark.master.host", "localhost") - val port = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "NetworkInputTracker" - val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) - val trackerActor = env.actorSystem.actorFor(url) - val timeout = 1.seconds val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) Await.result(future, timeout) } @@ -101,6 +102,7 @@ class RawInputDStream[T: ClassManifest]( override def receive = { case BlockPublished(blockId) => newBlocks += blockId + val future = trackerActor ! GotBlockIds(streamId, Array(blockId)) case GetBlockIds(time) => logInfo("Got request for block IDs for " + time) @@ -111,5 +113,6 @@ class RawInputDStream[T: ClassManifest]( receivingThread.interrupt() sender ! true } + } } diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index 298d9ef381..9702003805 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -1,11 +1,24 @@ package spark.streaming.examples import spark.util.IntParam +import spark.SparkContext +import spark.SparkContext._ import spark.storage.StorageLevel import spark.streaming._ import spark.streaming.StreamingContext._ +import WordCount2_ExtraFunctions._ + object WordCountRaw { + def moreWarmup(sc: SparkContext) { + (0 until 40).foreach {i => + sc.parallelize(1 to 20000000, 1000) + .map(_ % 1331).map(_.toString) + .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + .collect() + } + } + def main(args: Array[String]) { if (args.length != 7) { System.err.println("Usage: WordCountRaw ") @@ -20,16 +33,12 @@ object WordCountRaw { ssc.setBatchDuration(Milliseconds(batchMs)) // Make sure some tasks have started on each node - ssc.sc.parallelize(1 to 1000, 1000).count() - ssc.sc.parallelize(1 to 1000, 1000).count() - ssc.sc.parallelize(1 to 1000, 1000).count() + moreWarmup(ssc.sc) val rawStreams = (1 to streams).map(_ => ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray val union = new UnifiedDStream(rawStreams) - import WordCount2_ExtraFunctions._ - val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, -- cgit v1.2.3 From c63a6064584ea19d62e0abcbd3886d7b1e429ea1 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 10 Sep 2012 19:51:27 +0000 Subject: Made NewHadoopRDD broadcast its job configuration (same as HadoopRDD). 
--- core/src/main/scala/spark/NewHadoopRDD.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/NewHadoopRDD.scala b/core/src/main/scala/spark/NewHadoopRDD.scala index d024d38aa9..14f708a3f8 100644 --- a/core/src/main/scala/spark/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/NewHadoopRDD.scala @@ -28,7 +28,9 @@ class NewHadoopRDD[K, V]( @transient conf: Configuration) extends RDD[(K, V)](sc) { - private val serializableConf = new SerializableWritable(conf) + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it + val confBroadcast = sc.broadcast(new SerializableWritable(conf)) + // private val serializableConf = new SerializableWritable(conf) private val jobtrackerId: String = { val formatter = new SimpleDateFormat("yyyyMMddHHmm") @@ -41,7 +43,7 @@ class NewHadoopRDD[K, V]( @transient private val splits_ : Array[Split] = { val inputFormat = inputFormatClass.newInstance - val jobContext = new JobContext(serializableConf.value, jobId) + val jobContext = new JobContext(conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Split](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -54,9 +56,9 @@ class NewHadoopRDD[K, V]( override def compute(theSplit: Split) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopSplit] - val conf = serializableConf.value + val conf = confBroadcast.value.value val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0) - val context = new TaskAttemptContext(serializableConf.value, attemptId) + val context = new TaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance val reader = format.createRecordReader(split.serializableHadoopSplit.value, context) reader.initialize(split.serializableHadoopSplit.value, context) -- cgit v1.2.3 From e95ff45b53bf995d89f1825b9581cc18a083a438 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 13 Oct 2012 20:10:49 -0700 Subject: Implemented checkpointing of StreamingContext and DStream graph. 
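The recovery model is: a Checkpoint object captures the pieces needed to rebuild the job (master, framework name, the DStream graph, batch duration, checkpoint file and interval), is written to a Hadoop path on a fixed interval, and can later be handed back to a StreamingContext constructor. A rough sketch of the intended driver pattern (the master URL, application name, paths and durations are placeholders; the constructor and setter names match the patch below):

    import spark.streaming._
    import spark.streaming.StreamingContext._

    def createOrRecover(master: String, checkpointFile: String, restart: Boolean): StreamingContext = {
      if (restart) {
        // Deserialize the Checkpoint and rebuild the context and its DStream graph.
        new StreamingContext(checkpointFile)
      } else {
        val ssc = new StreamingContext(master, "CheckpointSketch")
        ssc.setBatchDuration(Seconds(1))
        // Have the Scheduler write a Checkpoint every 10 seconds of stream time.
        ssc.setCheckpointDetails(checkpointFile, Seconds(10))
        val lines = ssc.createTextFileStream("/tmp/checkpoint-sketch-input")
        lines.flatMap(_.split(" ")).map(w => (w, 1)).reduceByKey(_ + _).print()
        ssc
      }
    }

Either branch then calls ssc.start(); the FileStreamWithCheckpoint example added below follows this shape.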
--- core/src/main/scala/spark/SparkContext.scala | 4 +- .../main/scala/spark/streaming/Checkpoint.scala | 92 +++++++++++++++ .../src/main/scala/spark/streaming/DStream.scala | 123 ++++++++++++++------- .../main/scala/spark/streaming/DStreamGraph.scala | 80 ++++++++++++++ .../scala/spark/streaming/FileInputDStream.scala | 59 ++++++---- .../spark/streaming/ReducedWindowedDStream.scala | 80 +++++++------- .../src/main/scala/spark/streaming/Scheduler.scala | 33 +++--- .../main/scala/spark/streaming/StateDStream.scala | 20 ++-- .../scala/spark/streaming/StreamingContext.scala | 109 ++++++++++++------ .../examples/FileStreamWithCheckpoint.scala | 76 +++++++++++++ .../scala/spark/streaming/examples/Grep2.scala | 2 +- .../spark/streaming/examples/WordCount2.scala | 2 +- .../scala/spark/streaming/examples/WordMax2.scala | 2 +- .../spark/streaming/util/RecurringTimer.scala | 19 +++- 14 files changed, 536 insertions(+), 165 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/Checkpoint.scala create mode 100644 streaming/src/main/scala/spark/streaming/DStreamGraph.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bebebe8262..1d5131ad13 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -46,8 +46,8 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend import spark.storage.BlockManagerMaster class SparkContext( - master: String, - frameworkName: String, + val master: String, + val frameworkName: String, val sparkHome: String, val jars: Seq[String]) extends Logging { diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala new file mode 100644 index 0000000000..3bd8fd5a27 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -0,0 +1,92 @@ +package spark.streaming + +import spark.Utils + +import org.apache.hadoop.fs.{FileUtil, Path} +import org.apache.hadoop.conf.Configuration + +import java.io.{ObjectInputStream, ObjectOutputStream} + +class Checkpoint(@transient ssc: StreamingContext) extends Serializable { + val master = ssc.sc.master + val frameworkName = ssc.sc.frameworkName + val sparkHome = ssc.sc.sparkHome + val jars = ssc.sc.jars + val graph = ssc.graph + val batchDuration = ssc.batchDuration + val checkpointFile = ssc.checkpointFile + val checkpointInterval = ssc.checkpointInterval + + def saveToFile(file: String) { + val path = new Path(file) + val conf = new Configuration() + val fs = path.getFileSystem(conf) + if (fs.exists(path)) { + val bkPath = new Path(path.getParent, path.getName + ".bk") + FileUtil.copy(fs, path, fs, bkPath, true, true, conf) + println("Moved existing checkpoint file to " + bkPath) + } + val fos = fs.create(path) + val oos = new ObjectOutputStream(fos) + oos.writeObject(this) + oos.close() + fs.close() + } + + def toBytes(): Array[Byte] = { + val cp = new Checkpoint(ssc) + val bytes = Utils.serialize(cp) + bytes + } +} + +object Checkpoint { + + def loadFromFile(file: String): Checkpoint = { + val path = new Path(file) + val conf = new Configuration() + val fs = path.getFileSystem(conf) + if (!fs.exists(path)) { + throw new Exception("Could not read checkpoint file " + path) + } + val fis = fs.open(path) + val ois = new ObjectInputStream(fis) + val cp = 
ois.readObject.asInstanceOf[Checkpoint] + ois.close() + fs.close() + cp + } + + def fromBytes(bytes: Array[Byte]): Checkpoint = { + Utils.deserialize[Checkpoint](bytes) + } + + /*def toBytes(ssc: StreamingContext): Array[Byte] = { + val cp = new Checkpoint(ssc) + val bytes = Utils.serialize(cp) + bytes + } + + + def saveContext(ssc: StreamingContext, file: String) { + val cp = new Checkpoint(ssc) + val path = new Path(file) + val conf = new Configuration() + val fs = path.getFileSystem(conf) + if (fs.exists(path)) { + val bkPath = new Path(path.getParent, path.getName + ".bk") + FileUtil.copy(fs, path, fs, bkPath, true, true, conf) + println("Moved existing checkpoint file to " + bkPath) + } + val fos = fs.create(path) + val oos = new ObjectOutputStream(fos) + oos.writeObject(cp) + oos.close() + fs.close() + } + + def loadContext(file: String): StreamingContext = { + loadCheckpoint(file).createNewContext() + } + */ +} diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 7e8098c346..78e4c57647 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -2,20 +2,19 @@ package spark.streaming import spark.streaming.StreamingContext._ -import spark.RDD -import spark.UnionRDD -import spark.Logging +import spark._ import spark.SparkContext._ import spark.storage.StorageLevel -import spark.Partitioner import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import java.util.concurrent.ArrayBlockingQueue +import java.io.{ObjectInputStream, IOException, ObjectOutputStream} +import scala.Some -abstract class DStream[T: ClassManifest] (@transient val ssc: StreamingContext) -extends Logging with Serializable { +abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) +extends Serializable with Logging { initLogging() @@ -41,10 +40,10 @@ extends Logging with Serializable { */ // Variable to store the RDDs generated earlier in time - @transient protected val generatedRDDs = new HashMap[Time, RDD[T]] () + protected val generatedRDDs = new HashMap[Time, RDD[T]] () // Variable to be set to the first time seen by the DStream (effective time zero) - protected[streaming] var zeroTime: Time = null + protected var zeroTime: Time = null // Variable to specify storage level protected var storageLevel: StorageLevel = StorageLevel.NONE @@ -53,6 +52,9 @@ extends Logging with Serializable { protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint protected var checkpointInterval: Time = null + // Reference to whole DStream graph, so that checkpointing process can lock it + protected var graph: DStreamGraph = null + // Change this RDD's storage level def persist( storageLevel: StorageLevel, @@ -77,7 +79,7 @@ extends Logging with Serializable { // Turn on the default caching level for this RDD def cache(): DStream[T] = persist() - def isInitialized = (zeroTime != null) + def isInitialized() = (zeroTime != null) /** * This method initializes the DStream by setting the "zero" time, based on which @@ -85,15 +87,33 @@ extends Logging with Serializable { * its parent DStreams. 
*/ protected[streaming] def initialize(time: Time) { - if (zeroTime == null) { - zeroTime = time + if (zeroTime != null) { + throw new Exception("ZeroTime is already initialized, cannot initialize it again") } + zeroTime = time logInfo(this + " initialized") dependencies.foreach(_.initialize(zeroTime)) } + protected[streaming] def setContext(s: StreamingContext) { + if (ssc != null && ssc != s) { + throw new Exception("Context is already set, cannot set it again") + } + ssc = s + logInfo("Set context for " + this.getClass.getSimpleName) + dependencies.foreach(_.setContext(ssc)) + } + + protected[streaming] def setGraph(g: DStreamGraph) { + if (graph != null && graph != g) { + throw new Exception("Graph is already set, cannot set it again") + } + graph = g + dependencies.foreach(_.setGraph(graph)) + } + /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ - protected def isTimeValid (time: Time): Boolean = { + protected def isTimeValid(time: Time): Boolean = { if (!isInitialized) { throw new Exception (this.toString + " has not been initialized") } else if (time < zeroTime || ! (time - zeroTime).isMultipleOf(slideTime)) { @@ -158,13 +178,42 @@ extends Logging with Serializable { } } + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + println(this.getClass().getSimpleName + ".writeObject used") + if (graph != null) { + graph.synchronized { + if (graph.checkpointInProgress) { + oos.defaultWriteObject() + } else { + val msg = "Object of " + this.getClass.getName + " is being serialized " + + " possibly as a part of closure of an RDD operation. This is because " + + " the DStream object is being referred to from within the closure. " + + " Please rewrite the RDD operation inside this DStream to avoid this. " + + " This has been enforced to avoid bloating of Spark tasks " + + " with unnecessary objects." 
+ throw new java.io.NotSerializableException(msg) + } + } + } else { + throw new java.io.NotSerializableException("Graph is unexpectedly null when DStream is being serialized.") + } + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + println(this.getClass().getSimpleName + ".readObject used") + ois.defaultReadObject() + } + /** * -------------- * DStream operations * -------------- */ - - def map[U: ClassManifest](mapFunc: T => U) = new MappedDStream(this, ssc.sc.clean(mapFunc)) + def map[U: ClassManifest](mapFunc: T => U) = { + new MappedDStream(this, ssc.sc.clean(mapFunc)) + } def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]) = { new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc)) @@ -262,19 +311,15 @@ extends Logging with Serializable { // Get all the RDDs between fromTime to toTime (both included) def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { - val rdds = new ArrayBuffer[RDD[T]]() var time = toTime.floor(slideTime) - - while (time >= zeroTime && time >= fromTime) { getOrCompute(time) match { case Some(rdd) => rdds += rdd - case None => throw new Exception("Could not get old reduced RDD for time " + time) + case None => //throw new Exception("Could not get RDD for time " + time) } time -= slideTime } - rdds.toSeq } @@ -284,12 +329,16 @@ extends Logging with Serializable { } -abstract class InputDStream[T: ClassManifest] (@transient ssc: StreamingContext) - extends DStream[T](ssc) { +abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) + extends DStream[T](ssc_) { override def dependencies = List() - override def slideTime = ssc.batchDuration + override def slideTime = { + if (ssc == null) throw new Exception("ssc is null") + if (ssc.batchDuration == null) throw new Exception("ssc.batchDuration is null") + ssc.batchDuration + } def start() @@ -302,7 +351,7 @@ abstract class InputDStream[T: ClassManifest] (@transient ssc: StreamingContext) */ class MappedDStream[T: ClassManifest, U: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], mapFunc: T => U ) extends DStream[U](parent.ssc) { @@ -321,7 +370,7 @@ class MappedDStream[T: ClassManifest, U: ClassManifest] ( */ class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( - @transient parent: DStream[T], + parent: DStream[T], flatMapFunc: T => Traversable[U] ) extends DStream[U](parent.ssc) { @@ -340,7 +389,7 @@ class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( */ class FilteredDStream[T: ClassManifest]( - @transient parent: DStream[T], + parent: DStream[T], filterFunc: T => Boolean ) extends DStream[T](parent.ssc) { @@ -359,7 +408,7 @@ class FilteredDStream[T: ClassManifest]( */ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( - @transient parent: DStream[T], + parent: DStream[T], mapPartFunc: Iterator[T] => Iterator[U] ) extends DStream[U](parent.ssc) { @@ -377,7 +426,7 @@ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( * TODO */ -class GlommedDStream[T: ClassManifest](@transient parent: DStream[T]) +class GlommedDStream[T: ClassManifest](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { override def dependencies = List(parent) @@ -395,7 +444,7 @@ class GlommedDStream[T: ClassManifest](@transient parent: DStream[T]) */ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( - @transient parent: DStream[(K,V)], + parent: DStream[(K,V)], createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiner: (C, C) => C, @@ -420,7 +469,7 @@ class 
ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( * TODO */ -class UnifiedDStream[T: ClassManifest](@transient parents: Array[DStream[T]]) +class UnifiedDStream[T: ClassManifest](parents: Array[DStream[T]]) extends DStream[T](parents(0).ssc) { if (parents.length == 0) { @@ -459,7 +508,7 @@ class UnifiedDStream[T: ClassManifest](@transient parents: Array[DStream[T]]) */ class PerElementForEachDStream[T: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], foreachFunc: T => Unit ) extends DStream[Unit](parent.ssc) { @@ -490,7 +539,7 @@ class PerElementForEachDStream[T: ClassManifest] ( */ class PerRDDForEachDStream[T: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], foreachFunc: (RDD[T], Time) => Unit ) extends DStream[Unit](parent.ssc) { @@ -518,15 +567,15 @@ class PerRDDForEachDStream[T: ClassManifest] ( */ class TransformedDStream[T: ClassManifest, U: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], transformFunc: (RDD[T], Time) => RDD[U] ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies = List(parent) - override def slideTime: Time = parent.slideTime + override def slideTime: Time = parent.slideTime - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(transformFunc(_, validTime)) - } + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(transformFunc(_, validTime)) } +} diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala new file mode 100644 index 0000000000..67859e0131 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -0,0 +1,80 @@ +package spark.streaming + +import java.io.{ObjectInputStream, IOException, ObjectOutputStream} +import collection.mutable.ArrayBuffer + +final class DStreamGraph extends Serializable { + + private val inputStreams = new ArrayBuffer[InputDStream[_]]() + private val outputStreams = new ArrayBuffer[DStream[_]]() + + private[streaming] var zeroTime: Time = null + private[streaming] var checkpointInProgress = false; + + def started() = (zeroTime != null) + + def start(time: Time) { + this.synchronized { + if (started) { + throw new Exception("DStream graph computation already started") + } + zeroTime = time + outputStreams.foreach(_.initialize(zeroTime)) + inputStreams.par.foreach(_.start()) + } + + } + + def stop() { + this.synchronized { + inputStreams.par.foreach(_.stop()) + } + } + + private[streaming] def setContext(ssc: StreamingContext) { + this.synchronized { + outputStreams.foreach(_.setContext(ssc)) + } + } + + def addInputStream(inputStream: InputDStream[_]) { + inputStream.setGraph(this) + inputStreams += inputStream + } + + def addOutputStream(outputStream: DStream[_]) { + outputStream.setGraph(this) + outputStreams += outputStream + } + + def getInputStreams() = inputStreams.toArray + + def getOutputStreams() = outputStreams.toArray + + def generateRDDs(time: Time): Seq[Job] = { + this.synchronized { + outputStreams.flatMap(outputStream => outputStream.generateJob(time)) + } + } + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + this.synchronized { + checkpointInProgress = true + oos.defaultWriteObject() + checkpointInProgress = false + } + println("DStreamGraph.writeObject used") + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + 
this.synchronized { + checkpointInProgress = true + ois.defaultReadObject() + checkpointInProgress = false + } + println("DStreamGraph.readObject used") + } +} + diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala index 96a64f0018..29ae89616e 100644 --- a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala @@ -1,33 +1,45 @@ package spark.streaming -import spark.SparkContext import spark.RDD -import spark.BlockRDD import spark.UnionRDD -import spark.storage.StorageLevel -import spark.streaming._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import java.net.InetSocketAddress - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.PathFilter +import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import java.io.{ObjectInputStream, IOException} class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( - ssc: StreamingContext, - directory: Path, + @transient ssc_ : StreamingContext, + directory: String, filter: PathFilter = FileInputDStream.defaultPathFilter, newFilesOnly: Boolean = true) - extends InputDStream[(K, V)](ssc) { - - val fs = directory.getFileSystem(new Configuration()) + extends InputDStream[(K, V)](ssc_) { + + @transient private var path_ : Path = null + @transient private var fs_ : FileSystem = null + + /* + @transient @noinline lazy val path = { + //if (directory == null) throw new Exception("directory is null") + //println(directory) + new Path(directory) + } + @transient lazy val fs = path.getFileSystem(new Configuration()) + */ + var lastModTime: Long = 0 - + + def path(): Path = { + if (path_ == null) path_ = new Path(directory) + path_ + } + + def fs(): FileSystem = { + if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) + fs_ + } + override def start() { if (newFilesOnly) { lastModTime = System.currentTimeMillis() @@ -58,7 +70,7 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K } } - val newFiles = fs.listStatus(directory, newFilter) + val newFiles = fs.listStatus(path, newFilter) logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) if (newFiles.length > 0) { lastModTime = newFilter.latestModTime @@ -67,10 +79,19 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) Some(newRDD) } + /* + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + println(this.getClass().getSimpleName + ".readObject used") + ois.defaultReadObject() + println("HERE HERE" + this.directory) + } + */ + } object FileInputDStream { - val defaultPathFilter = new PathFilter { + val defaultPathFilter = new PathFilter with Serializable { def accept(path: Path): Boolean = { val file = path.getName() if (file.startsWith(".") || file.endsWith("_tmp")) { diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index b0beaba94d..e161b5ba92 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -10,9 +10,10 @@ import spark.SparkContext._ import 
spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer +import collection.SeqProxy class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( - @transient parent: DStream[(K, V)], + parent: DStream[(K, V)], reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, _windowTime: Time, @@ -46,6 +47,8 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( } override def compute(validTime: Time): Option[RDD[(K, V)]] = { + val reduceF = reduceFunc + val invReduceF = invReduceFunc val currentTime = validTime val currentWindow = Interval(currentTime - windowTime + parent.slideTime, currentTime) @@ -84,54 +87,47 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( // Cogroup the reduced RDDs and merge the reduced values val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner) - val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ - val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValuesFunc) + //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ - Some(mergedValuesRDD) - } - - def mergeValues(numOldValues: Int, numNewValues: Int)(seqOfValues: Seq[Seq[V]]): V = { - - if (seqOfValues.size != 1 + numOldValues + numNewValues) { - throw new Exception("Unexpected number of sequences of reduced values") - } - - // Getting reduced values "old time steps" that will be removed from current window - val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) - - // Getting reduced values "new time steps" - val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) - - if (seqOfValues(0).isEmpty) { + val numOldValues = oldRDDs.size + val numNewValues = newRDDs.size - // If previous window's reduce value does not exist, then at least new values should exist - if (newValues.isEmpty) { - throw new Exception("Neither previous window has value for key, nor new values found") + val mergeValues = (seqOfValues: Seq[Seq[V]]) => { + if (seqOfValues.size != 1 + numOldValues + numNewValues) { + throw new Exception("Unexpected number of sequences of reduced values") } + // Getting reduced values "old time steps" that will be removed from current window + val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) + // Getting reduced values "new time steps" + val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) + if (seqOfValues(0).isEmpty) { + // If previous window's reduce value does not exist, then at least new values should exist + if (newValues.isEmpty) { + throw new Exception("Neither previous window has value for key, nor new values found") + } + // Reduce the new values + newValues.reduce(reduceF) // return + } else { + // Get the previous window's reduced value + var tempValue = seqOfValues(0).head + // If old values exists, then inverse reduce then from previous value + if (!oldValues.isEmpty) { + tempValue = invReduceF(tempValue, oldValues.reduce(reduceF)) + } + // If new values exists, then reduce them with previous value + if (!newValues.isEmpty) { + tempValue = reduceF(tempValue, newValues.reduce(reduceF)) + } + tempValue // return + } + } - // Reduce the new values - // println("new values = " + newValues.map(_.toString).reduce(_ + " " + _)) - return newValues.reduce(reduceFunc) - } else { + val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) - // Get the 
previous window's reduced value - var tempValue = seqOfValues(0).head + Some(mergedValuesRDD) + } - // If old values exists, then inverse reduce then from previous value - if (!oldValues.isEmpty) { - // println("old values = " + oldValues.map(_.toString).reduce(_ + " " + _)) - tempValue = invReduceFunc(tempValue, oldValues.reduce(reduceFunc)) - } - // If new values exists, then reduce them with previous value - if (!newValues.isEmpty) { - // println("new values = " + newValues.map(_.toString).reduce(_ + " " + _)) - tempValue = reduceFunc(tempValue, newValues.reduce(reduceFunc)) - } - // println("final value = " + tempValue) - return tempValue - } - } } diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index d2e907378d..d62b7e7140 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -11,45 +11,44 @@ import scala.collection.mutable.HashMap sealed trait SchedulerMessage case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage -class Scheduler( - ssc: StreamingContext, - inputStreams: Array[InputDStream[_]], - outputStreams: Array[DStream[_]]) +class Scheduler(ssc: StreamingContext) extends Logging { initLogging() + val graph = ssc.graph val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt val jobManager = new JobManager(ssc, concurrentJobs) val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock") val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] val timer = new RecurringTimer(clock, ssc.batchDuration, generateRDDs(_)) - + + def start() { - val zeroTime = Time(timer.start()) - outputStreams.foreach(_.initialize(zeroTime)) - inputStreams.par.foreach(_.start()) + if (graph.started) { + timer.restart(graph.zeroTime.milliseconds) + } else { + val zeroTime = Time(timer.start()) + graph.start(zeroTime) + } logInfo("Scheduler started") } def stop() { timer.stop() - inputStreams.par.foreach(_.stop()) + graph.stop() logInfo("Scheduler stopped") } - def generateRDDs (time: Time) { + def generateRDDs(time: Time) { SparkEnv.set(ssc.env) logInfo("\n-----------------------------------------------------\n") logInfo("Generating RDDs for time " + time) - outputStreams.foreach(outputStream => { - outputStream.generateJob(time) match { - case Some(job) => submitJob(job) - case None => - } - } - ) + graph.generateRDDs(time).foreach(submitJob) logInfo("Generated RDDs for time " + time) + if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { + ssc.checkpoint() + } } def submitJob(job: Job) { diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index c40f70c91d..d223f25dfc 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -7,6 +7,12 @@ import spark.MapPartitionsRDD import spark.SparkContext._ import spark.storage.StorageLevel + +class StateRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: Iterator[T] => Iterator[U], rememberPartitioner: Boolean) + extends MapPartitionsRDD[U, T](prev, f) { + override val partitioner = if (rememberPartitioner) prev.partitioner else None +} + class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( @transient parent: 
DStream[(K, V)], updateFunc: (Iterator[(K, Seq[V], S)]) => Iterator[(K, S)], @@ -14,11 +20,6 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife rememberPartitioner: Boolean ) extends DStream[(K, S)](parent.ssc) { - class SpecialMapPartitionsRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: Iterator[T] => Iterator[U]) - extends MapPartitionsRDD(prev, f) { - override val partitioner = if (rememberPartitioner) prev.partitioner else None - } - override def dependencies = List(parent) override def slideTime = parent.slideTime @@ -79,19 +80,18 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife // first map the cogrouped tuple to tuples of required type, // and then apply the update function val updateFuncLocal = updateFunc - val mapPartitionFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { + val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { val i = iterator.map(t => { (t._1, t._2._1, t._2._2.headOption.getOrElse(null.asInstanceOf[S])) }) updateFuncLocal(i) } val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) - val stateRDD = new SpecialMapPartitionsRDD(cogroupedRDD, mapPartitionFunc) + val stateRDD = new StateRDD(cogroupedRDD, finalFunc, rememberPartitioner) //logDebug("Generating state RDD for time " + validTime) return Some(stateRDD) } case None => { // If parent RDD does not exist, then return old state RDD - //logDebug("Generating state RDD for time " + validTime + " (no change)") return Some(prevStateRDD) } } @@ -107,12 +107,12 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife // first map the grouped tuple to tuples of required type, // and then apply the update function val updateFuncLocal = updateFunc - val mapPartitionFunc = (iterator: Iterator[(K, Seq[V])]) => { + val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, null.asInstanceOf[S]))) } val groupedRDD = parentRDD.groupByKey(partitioner) - val sessionRDD = new SpecialMapPartitionsRDD(groupedRDD, mapPartitionFunc) + val sessionRDD = new StateRDD(groupedRDD, finalFunc, rememberPartitioner) //logDebug("Generating state RDD for time " + validTime + " (first)") return Some(sessionRDD) } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 12f3626680..1499ef4ea2 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -21,31 +21,70 @@ import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat -class StreamingContext (@transient val sc: SparkContext) extends Logging { +class StreamingContext ( + sc_ : SparkContext, + cp_ : Checkpoint + ) extends Logging { + + def this(sparkContext: SparkContext) = this(sparkContext, null) def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) = - this(new SparkContext(master, frameworkName, sparkHome, jars)) + this(new SparkContext(master, frameworkName, sparkHome, jars), null) + + def this(file: String) = this(null, Checkpoint.loadFromFile(file)) + + def this(cp_ : Checkpoint) = this(null, cp_) initLogging() + if (sc_ == null && cp_ == null) { + throw new Exception("Streaming Context cannot be initilalized with " + + "both SparkContext and checkpoint as null") + } + + val 
isCheckpointPresent = (cp_ != null) + + val sc: SparkContext = { + if (isCheckpointPresent) { + new SparkContext(cp_.master, cp_.frameworkName, cp_.sparkHome, cp_.jars) + } else { + sc_ + } + } + val env = SparkEnv.get - - val inputStreams = new ArrayBuffer[InputDStream[_]]() - val outputStreams = new ArrayBuffer[DStream[_]]() + + val graph: DStreamGraph = { + if (isCheckpointPresent) { + + cp_.graph.setContext(this) + cp_.graph + } else { + new DStreamGraph() + } + } + val nextNetworkInputStreamId = new AtomicInteger(0) - var batchDuration: Time = null - var scheduler: Scheduler = null + var batchDuration: Time = if (isCheckpointPresent) cp_.batchDuration else null + var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null + var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null var networkInputTracker: NetworkInputTracker = null - var receiverJobThread: Thread = null - - def setBatchDuration(duration: Long) { - setBatchDuration(Time(duration)) - } - + var receiverJobThread: Thread = null + var scheduler: Scheduler = null + def setBatchDuration(duration: Time) { + if (batchDuration != null) { + throw new Exception("Batch duration alread set as " + batchDuration + + ". cannot set it again.") + } batchDuration = duration } + + def setCheckpointDetails(file: String, interval: Time) { + checkpointFile = file + checkpointInterval = interval + } private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() @@ -59,7 +98,7 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { converter: (InputStream) => Iterator[T] ): DStream[T] = { val inputStream = new ObjectInputDStream[T](this, hostname, port, converter) - inputStreams += inputStream + graph.addInputStream(inputStream) inputStream } @@ -69,7 +108,7 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_2 ): DStream[T] = { val inputStream = new RawInputDStream[T](this, hostname, port, storageLevel) - inputStreams += inputStream + graph.addInputStream(inputStream) inputStream } @@ -94,8 +133,8 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { V: ClassManifest, F <: NewInputFormat[K, V]: ClassManifest ](directory: String): DStream[(K, V)] = { - val inputStream = new FileInputDStream[K, V, F](this, new Path(directory)) - inputStreams += inputStream + val inputStream = new FileInputDStream[K, V, F](this, directory) + graph.addInputStream(inputStream) inputStream } @@ -113,24 +152,31 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { defaultRDD: RDD[T] = null ): DStream[T] = { val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD) - inputStreams += inputStream + graph.addInputStream(inputStream) inputStream } - def createQueueStream[T: ClassManifest](iterator: Iterator[RDD[T]]): DStream[T] = { + def createQueueStream[T: ClassManifest](iterator: Array[RDD[T]]): DStream[T] = { val queue = new Queue[RDD[T]] val inputStream = createQueueStream(queue, true, null) queue ++= iterator inputStream - } + } + + /** + * This function registers a InputDStream as an input stream that will be + * started (InputDStream.start() called) to get the input data streams. + */ + def registerInputStream(inputStream: InputDStream[_]) { + graph.addInputStream(inputStream) + } - /** * This function registers a DStream as an output stream that will be * computed every interval. 
*/ - def registerOutputStream (outputStream: DStream[_]) { - outputStreams += outputStream + def registerOutputStream(outputStream: DStream[_]) { + graph.addOutputStream(outputStream) } /** @@ -143,13 +189,9 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { if (batchDuration < Milliseconds(100)) { logWarning("Batch duration of " + batchDuration + " is very low") } - if (inputStreams.size == 0) { - throw new Exception("No input streams created, so nothing to take input from") - } - if (outputStreams.size == 0) { + if (graph.getOutputStreams().size == 0) { throw new Exception("No output streams registered, so nothing to execute") } - } /** @@ -157,7 +199,7 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { */ def start() { verify() - val networkInputStreams = inputStreams.filter(s => s match { + val networkInputStreams = graph.getInputStreams().filter(s => s match { case n: NetworkInputDStream[_] => true case _ => false }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray @@ -169,8 +211,9 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { } Thread.sleep(1000) - // Start the scheduler - scheduler = new Scheduler(this, inputStreams.toArray, outputStreams.toArray) + + // Start the scheduler + scheduler = new Scheduler(this) scheduler.start() } @@ -189,6 +232,10 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { logInfo("StreamingContext stopped") } + + def checkpoint() { + new Checkpoint(this).saveToFile(checkpointFile) + } } diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala new file mode 100644 index 0000000000..c725035a8a --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala @@ -0,0 +1,76 @@ +package spark.streaming.examples + +import spark.streaming._ +import spark.streaming.StreamingContext._ +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + +object FileStreamWithCheckpoint { + + def main(args: Array[String]) { + + if (args.size != 3) { + println("FileStreamWithCheckpoint ") + println("FileStreamWithCheckpoint restart ") + System.exit(-1) + } + + val directory = new Path(args(1)) + val checkpointFile = args(2) + + val ssc: StreamingContext = { + + if (args(0) == "restart") { + + // Recreated streaming context from specified checkpoint file + new StreamingContext(checkpointFile) + + } else { + + // Create directory if it does not exist + val fs = directory.getFileSystem(new Configuration()) + if (!fs.exists(directory)) fs.mkdirs(directory) + + // Create new streaming context + val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint") + ssc_.setBatchDuration(Seconds(1)) + ssc_.setCheckpointDetails(checkpointFile, Seconds(1)) + + // Setup the streaming computation + val inputStream = ssc_.createTextFileStream(directory.toString) + val words = inputStream.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + + ssc_ + } + } + + // Start the stream computation + startFileWritingThread(directory.toString) + ssc.start() + } + + def startFileWritingThread(directory: String) { + + val fs = new Path(directory).getFileSystem(new Configuration()) + + val fileWritingThread = new Thread() { + override def run() { + val r = new scala.util.Random() + val text = "This is a sample text file with a random number " + while(true) { + val 
number = r.nextInt() + val file = new Path(directory, number.toString) + val fos = fs.create(file) + fos.writeChars(text + number) + fos.close() + println("Created text file " + file) + Thread.sleep(1000) + } + } + } + fileWritingThread.start() + } + +} diff --git a/streaming/src/main/scala/spark/streaming/examples/Grep2.scala b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala index 7237142c7c..b1faa65c17 100644 --- a/streaming/src/main/scala/spark/streaming/examples/Grep2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala @@ -50,7 +50,7 @@ object Grep2 { println("Data count: " + data.count()) val sentences = new ConstantInputDStream(ssc, data) - ssc.inputStreams += sentences + ssc.registerInputStream(sentences) sentences.filter(_.contains("Culpepper")).count().foreachRDD(r => println("Grep count: " + r.collect().mkString)) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index c22949d7b9..8390f4af94 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -93,7 +93,7 @@ object WordCount2 { println("Data count: " + data.count()) val sentences = new ConstantInputDStream(ssc, data) - ssc.inputStreams += sentences + ssc.registerInputStream(sentences) import WordCount2_ExtraFunctions._ diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala index 3658cb302d..fc7567322b 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala @@ -50,7 +50,7 @@ object WordMax2 { println("Data count: " + data.count()) val sentences = new ConstantInputDStream(ssc, data) - ssc.inputStreams += sentences + ssc.registerInputStream(sentences) import WordCount2_ExtraFunctions._ diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala index 5da9fa6ecc..7f19b26a79 100644 --- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala @@ -17,12 +17,23 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => } var nextTime = 0L - - def start(): Long = { - nextTime = (math.floor(clock.currentTime / period) + 1).toLong * period - thread.start() + + def start(startTime: Long): Long = { + nextTime = startTime + thread.start() nextTime } + + def start(): Long = { + val startTime = math.ceil(clock.currentTime / period).toLong * period + start(startTime) + } + + def restart(originalStartTime: Long): Long = { + val gap = clock.currentTime - originalStartTime + val newStartTime = math.ceil(gap / period).toLong * period + originalStartTime + start(newStartTime) + } def stop() { thread.interrupt() -- cgit v1.2.3 From ac12abc17ff90ec99192f3c3de4d1d390969e635 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 29 Oct 2012 11:55:27 -0700 Subject: Modified RDD API to make dependencies a var (therefore can be changed to checkpointed hadoop rdd) and othere references to parent RDDs either through dependencies or through a weak reference (to allow finalizing when dependencies do not refer to it any more). 
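The convention after this change is that an RDD reaches its parents only through the dependencies list (now a var, so it can be swapped for a dependency on a checkpointed Hadoop RDD), while any constructor argument that used to hold the parent becomes a transient weak reference, letting the original lineage be garbage-collected once nothing else needs it. A hedged sketch of a one-parent transformation under the new convention (UppercasedRDD is hypothetical and sits in package spark.rdd only so it can use the private[spark] firstParent helper; the real subclasses are in the diff below):

    package spark.rdd

    import java.lang.ref.WeakReference
    import spark.{RDD, Split}

    // Hypothetical one-parent RDD written against the new convention.
    class UppercasedRDD(@transient prev: WeakReference[RDD[String]])
      extends RDD[String](prev.get) { // one-parent constructor installs a OneToOneDependency

      // The parent is reached through the dependency list, never through a strong
      // field on this object, so it can be freed after checkpointing.
      override def splits = firstParent[String].splits
      override def compute(split: Split) = firstParent[String].iterator(split).map(_.toUpperCase)
    }

Callers can still pass a plain RDD because SparkContext now carries an implicit rddToWeakRefRDD conversion that wraps it in a WeakReference.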
--- core/src/main/scala/spark/PairRDDFunctions.scala | 24 ++++++++++----------- core/src/main/scala/spark/ParallelCollection.scala | 8 +++---- core/src/main/scala/spark/RDD.scala | 25 ++++++++++++++++------ core/src/main/scala/spark/SparkContext.scala | 4 ++++ core/src/main/scala/spark/rdd/CartesianRDD.scala | 21 ++++++++++++------ core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 13 +++++++---- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 19 ++++++++++------ core/src/main/scala/spark/rdd/FilteredRDD.scala | 12 +++++++---- core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 10 ++++----- core/src/main/scala/spark/rdd/GlommedRDD.scala | 9 ++++---- core/src/main/scala/spark/rdd/HadoopRDD.scala | 4 +--- .../main/scala/spark/rdd/MapPartitionsRDD.scala | 12 +++++------ .../spark/rdd/MapPartitionsWithSplitRDD.scala | 10 ++++----- core/src/main/scala/spark/rdd/MappedRDD.scala | 12 +++++------ core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 9 ++++---- core/src/main/scala/spark/rdd/PipedRDD.scala | 16 +++++++------- core/src/main/scala/spark/rdd/SampledRDD.scala | 15 ++++++------- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 13 ++++++----- core/src/main/scala/spark/rdd/UnionRDD.scala | 20 ++++++++++------- 19 files changed, 149 insertions(+), 107 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index e5bb639cfd..f52af08125 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -23,6 +23,7 @@ import spark.partial.BoundedDouble import spark.partial.PartialResult import spark.rdd._ import spark.SparkContext._ +import java.lang.ref.WeakReference /** * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. 
@@ -624,23 +625,22 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( } private[spark] -class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override val partitioner = prev.partitioner - override def compute(split: Split) = prev.iterator(split).map{case (k, v) => (k, f(v))} +class MappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V => U) + extends RDD[(K, U)](prev.get) { + + override def splits = firstParent[(K, V)].splits + override val partitioner = firstParent[(K, V)].partitioner + override def compute(split: Split) = firstParent[(K, V)].iterator(split).map{case (k, v) => (k, f(v))} } private[spark] -class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U]) - extends RDD[(K, U)](prev.context) { - - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override val partitioner = prev.partitioner +class FlatMappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U]) + extends RDD[(K, U)](prev.get) { + override def splits = firstParent[(K, V)].splits + override val partitioner = firstParent[(K, V)].partitioner override def compute(split: Split) = { - prev.iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) } + firstParent[(K, V)].iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) } } } diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9b57ae3b4f..ad06ee9736 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -22,13 +22,13 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( } private[spark] class ParallelCollection[T: ClassManifest]( - sc: SparkContext, + @transient sc_ : SparkContext, @transient data: Seq[T], numSlices: Int) - extends RDD[T](sc) { + extends RDD[T](sc_, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split - // instead. + // instead. UPDATE: With the new changes to enable checkpointing, this an be done. @transient val splits_ = { @@ -41,8 +41,6 @@ private[spark] class ParallelCollection[T: ClassManifest]( override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator override def preferredLocations(s: Split): Seq[String] = Nil - - override val dependencies: List[Dependency[_]] = Nil } private object ParallelCollection { diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 338dff4061..c9f3763f73 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -72,7 +72,14 @@ import SparkContext._ * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details * on RDD internals. 
*/ -abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serializable { +abstract class RDD[T: ClassManifest]( + @transient var sc: SparkContext, + @transient var dependencies_ : List[Dependency[_]] = Nil + ) extends Serializable { + + + def this(@transient oneParent: RDD[_]) = + this(oneParent.context , List(new OneToOneDependency(oneParent))) // Methods that must be implemented by subclasses: @@ -83,10 +90,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def compute(split: Split): Iterator[T] /** How this RDD depends on any parent RDDs. */ - @transient val dependencies: List[Dependency[_]] + def dependencies: List[Dependency[_]] = dependencies_ + //var dependencies: List[Dependency[_]] = dependencies_ - // Methods available on all RDDs: - /** Record user function generating this RDD. */ private[spark] val origin = Utils.getSparkCallSite @@ -106,8 +112,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE - - /** + + private[spark] def firstParent[U: ClassManifest] = dependencies.head.rdd.asInstanceOf[RDD[U]] + private[spark] def parent[U: ClassManifest](id: Int) = dependencies(id).rdd.asInstanceOf[RDD[U]] + + // Methods available on all RDDs: + + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. */ @@ -129,7 +140,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - + private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = { if (!level.useDisk && level.replication < 2) { throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 0d37075ef3..6b957a6356 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -3,6 +3,7 @@ package spark import java.io._ import java.util.concurrent.atomic.AtomicInteger import java.net.{URI, URLClassLoader} +import java.lang.ref.WeakReference import scala.collection.Map import scala.collection.generic.Growable @@ -695,6 +696,9 @@ object SparkContext { /** Find the JAR that contains the class of a particular object */ def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass) + + implicit def rddToWeakRefRDD[T: ClassManifest](rdd: RDD[T]) = new WeakReference(rdd) + } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 7c354b6b2e..c97b835630 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -4,6 +4,7 @@ import spark.NarrowDependency import spark.RDD import spark.SparkContext import spark.Split +import java.lang.ref.WeakReference private[spark] class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable { @@ -13,13 +14,17 @@ class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with private[spark] class CartesianRDD[T: ClassManifest, U:ClassManifest]( sc: SparkContext, - rdd1: RDD[T], - rdd2: RDD[U]) + rdd1_ : WeakReference[RDD[T]], + rdd2_ : WeakReference[RDD[U]]) extends RDD[Pair[T, 
U]](sc) with Serializable { - + + def rdd1 = rdd1_.get + def rdd2 = rdd2_.get + val numSplitsInRdd2 = rdd2.splits.size - + + // TODO: make this null when finishing checkpoint @transient val splits_ = { // create the cross product split @@ -31,6 +36,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( array } + // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override def preferredLocations(split: Split) = { @@ -42,8 +48,9 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( val currSplit = split.asInstanceOf[CartesianSplit] for (x <- rdd1.iterator(currSplit.s1); y <- rdd2.iterator(currSplit.s2)) yield (x, y) } - - override val dependencies = List( + + // TODO: make this null when finishing checkpoint + var deps = List( new NarrowDependency(rdd1) { def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2) }, @@ -51,4 +58,6 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2) } ) + + override def dependencies = deps } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 50bec9e63b..af54ac2fa0 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -31,12 +31,13 @@ private[spark] class CoGroupAggregator with Serializable class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) - extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging { + extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging { val aggr = new CoGroupAggregator - + + // TODO: make this null when finishing checkpoint @transient - override val dependencies = { + var deps = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) @@ -50,7 +51,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) } deps.toList } - + + override def dependencies = deps + + // TODO: make this null when finishing checkpoint @transient val splits_ : Array[Split] = { val firstRdd = rdds.head @@ -68,6 +72,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) array } + // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override val partitioner = Some(part) diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 0967f4f5df..573acf8893 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -14,11 +14,14 @@ private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) exten * This transformation is useful when an RDD with many partitions gets filtered into a smaller one, * or to avoid having a large number of small tasks when processing a directory with many files. 
*/ -class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int) - extends RDD[T](prev.context) { +class CoalescedRDD[T: ClassManifest]( + @transient prev: RDD[T], // TODO: Make this a weak reference + maxPartitions: Int) + extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs + // TODO: make this null when finishing checkpoint @transient val splits_ : Array[Split] = { - val prevSplits = prev.splits + val prevSplits = firstParent[T].splits if (prevSplits.length < maxPartitions) { prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) } } else { @@ -30,18 +33,22 @@ class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int) } } + // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override def compute(split: Split): Iterator[T] = { split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { - parentSplit => prev.iterator(parentSplit) + parentSplit => firstParent[T].iterator(parentSplit) } } - val dependencies = List( - new NarrowDependency(prev) { + // TODO: make this null when finishing checkpoint + var deps = List( + new NarrowDependency(firstParent) { def getParents(id: Int): Seq[Int] = splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index) } ) + + override def dependencies = deps } diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index dfe9dc73f3..cc2a3acd3a 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -3,10 +3,14 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] -class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).filter(f) +class FilteredRDD[T: ClassManifest]( + @transient prev: WeakReference[RDD[T]], + f: T => Boolean) + extends RDD[T](prev.get) { + + override def splits = firstParent[T].splits + override def compute(split: Split) = firstParent[T].iterator(split).filter(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index 3534dc8057..34bd784c13 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -3,14 +3,14 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + @transient prev: WeakReference[RDD[T]], f: T => TraversableOnce[U]) - extends RDD[U](prev.context) { + extends RDD[U](prev.get) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).flatMap(f) + override def splits = firstParent[T].splits + override def compute(split: Split) = firstParent[T].iterator(split).flatMap(f) } diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index e30564f2da..9321e89dcd 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -3,10 +3,11 @@ package spark.rdd import 
spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] -class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator +class GlommedRDD[T: ClassManifest](@transient prev: WeakReference[RDD[T]]) + extends RDD[Array[T]](prev.get) { + override def splits = firstParent[T].splits + override def compute(split: Split) = Array(firstParent[T].iterator(split).toArray).iterator } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index bf29a1f075..a12531ea89 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -46,7 +46,7 @@ class HadoopRDD[K, V]( keyClass: Class[K], valueClass: Class[V], minSplits: Int) - extends RDD[(K, V)](sc) { + extends RDD[(K, V)](sc, Nil) { // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it val confBroadcast = sc.broadcast(new SerializableWritable(conf)) @@ -115,6 +115,4 @@ class HadoopRDD[K, V]( val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } - - override val dependencies: List[Dependency[_]] = Nil } diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index a904ef62c3..bad872c430 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -3,17 +3,17 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + @transient prev: WeakReference[RDD[T]], f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false) - extends RDD[U](prev.context) { + extends RDD[U](prev.get) { - override val partitioner = if (preservesPartitioning) prev.partitioner else None + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = f(prev.iterator(split)) + override def splits = firstParent[T].splits + override def compute(split: Split) = f(firstParent[T].iterator(split)) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index adc541694e..d7b238b05d 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -3,6 +3,7 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference /** * A variant of the MapPartitionsRDD that passes the split index into the @@ -11,11 +12,10 @@ import spark.Split */ private[spark] class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + @transient prev: WeakReference[RDD[T]], f: (Int, Iterator[T]) => Iterator[U]) - extends RDD[U](prev.context) { + extends RDD[U](prev.get) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) 
= f(split.index, prev.iterator(split)) + override def splits = firstParent[T].splits + override def compute(split: Split) = f(split.index, firstParent[T].iterator(split)) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 59bedad8ef..126c6f332b 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -3,14 +3,14 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] class MappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + @transient prev: WeakReference[RDD[T]], f: T => U) - extends RDD[U](prev.context) { - - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).map(f) + extends RDD[U](prev.get) { + + override def splits = firstParent[T].splits + override def compute(split: Split) = firstParent[T].iterator(split).map(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index 7a1a0fb87d..c12df5839e 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -23,11 +23,12 @@ class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit } class NewHadoopRDD[K, V]( - sc: SparkContext, + sc : SparkContext, inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], valueClass: Class[V], + keyClass: Class[K], + valueClass: Class[V], @transient conf: Configuration) - extends RDD[(K, V)](sc) + extends RDD[(K, V)](sc, Nil) with HadoopMapReduceUtil { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it @@ -92,6 +93,4 @@ class NewHadoopRDD[K, V]( val theSplit = split.asInstanceOf[NewHadoopSplit] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") } - - override val dependencies: List[Dependency[_]] = Nil } diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index 98ea0c92d6..d54579d6d1 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -19,18 +19,18 @@ import spark.Split * (printing them one per line) and returns the output as a collection of strings. */ class PipedRDD[T: ClassManifest]( - parent: RDD[T], command: Seq[String], envVars: Map[String, String]) - extends RDD[String](parent.context) { + @transient prev: RDD[T], + command: Seq[String], + envVars: Map[String, String]) + extends RDD[String](prev) { - def this(parent: RDD[T], command: Seq[String]) = this(parent, command, Map()) + def this(@transient prev: RDD[T], command: Seq[String]) = this(prev, command, Map()) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. 
by spaces) - def this(parent: RDD[T], command: String) = this(parent, PipedRDD.tokenize(command)) + def this(@transient prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) - override def splits = parent.splits - - override val dependencies = List(new OneToOneDependency(parent)) + override def splits = firstParent[T].splits override def compute(split: Split): Iterator[String] = { val pb = new ProcessBuilder(command) @@ -55,7 +55,7 @@ class PipedRDD[T: ClassManifest]( override def run() { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) - for (elem <- parent.iterator(split)) { + for (elem <- firstParent[T].iterator(split)) { out.println(elem) } out.close() diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index 87a5268f27..00b521b130 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -7,6 +7,7 @@ import cern.jet.random.engine.DRand import spark.RDD import spark.OneToOneDependency import spark.Split +import java.lang.ref.WeakReference private[spark] class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable { @@ -14,24 +15,22 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali } class SampledRDD[T: ClassManifest]( - prev: RDD[T], + @transient prev: WeakReference[RDD[T]], withReplacement: Boolean, frac: Double, seed: Int) - extends RDD[T](prev.context) { + extends RDD[T](prev.get) { @transient val splits_ = { val rg = new Random(seed) - prev.splits.map(x => new SampledRDDSplit(x, rg.nextInt)) + firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } override def splits = splits_.asInstanceOf[Array[Split]] - override val dependencies = List(new OneToOneDependency(prev)) - override def preferredLocations(split: Split) = - prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) + firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) override def compute(splitIn: Split) = { val split = splitIn.asInstanceOf[SampledRDDSplit] @@ -39,7 +38,7 @@ class SampledRDD[T: ClassManifest]( // For large datasets, the expected number of occurrences of each element in a sample with // replacement is Poisson(frac). We use that to get a count for each element. val poisson = new Poisson(frac, new DRand(split.seed)) - prev.iterator(split.prev).flatMap { element => + firstParent[T].iterator(split.prev).flatMap { element => val count = poisson.nextInt() if (count == 0) { Iterator.empty // Avoid object allocation when we return 0 items, which is quite often @@ -49,7 +48,7 @@ class SampledRDD[T: ClassManifest]( } } else { // Sampling without replacement val rand = new Random(split.seed) - prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac)) + firstParent[T].iterator(split.prev).filter(x => (rand.nextDouble <= frac)) } } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 145e419c53..62867dab4f 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -5,6 +5,7 @@ import spark.RDD import spark.ShuffleDependency import spark.SparkEnv import spark.Split +import java.lang.ref.WeakReference private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { override val index = idx @@ -19,8 +20,9 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { * @tparam V the value class. 
*/ class ShuffledRDD[K, V]( - @transient parent: RDD[(K, V)], - part: Partitioner) extends RDD[(K, V)](parent.context) { + @transient prev: WeakReference[RDD[(K, V)]], + part: Partitioner) + extends RDD[(K, V)](prev.get.context, List(new ShuffleDependency(prev.get, part))) { override val partitioner = Some(part) @@ -31,10 +33,11 @@ class ShuffledRDD[K, V]( override def preferredLocations(split: Split) = Nil - val dep = new ShuffleDependency(parent, part) - override val dependencies = List(dep) + //val dep = new ShuffleDependency(parent, part) + //override val dependencies = List(dep) override def compute(split: Split): Iterator[(K, V)] = { - SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index) + val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId + SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index f0b9225f7c..0a61a2d1f5 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -7,6 +7,7 @@ import spark.RangeDependency import spark.RDD import spark.SparkContext import spark.Split +import java.lang.ref.WeakReference private[spark] class UnionSplit[T: ClassManifest]( idx: Int, @@ -22,10 +23,10 @@ private[spark] class UnionSplit[T: ClassManifest]( class UnionRDD[T: ClassManifest]( sc: SparkContext, - @transient rdds: Seq[RDD[T]]) - extends RDD[T](sc) - with Serializable { - + @transient rdds: Seq[RDD[T]]) // TODO: Make this a weak reference + extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs + + // TODO: make this null when finishing checkpoint @transient val splits_ : Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) @@ -37,19 +38,22 @@ class UnionRDD[T: ClassManifest]( array } + // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ - @transient - override val dependencies = { + // TODO: make this null when finishing checkpoint + @transient var deps = { val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) + deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) pos += rdd.splits.size } deps.toList } - + + override def dependencies = deps + override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() override def preferredLocations(s: Split): Seq[String] = -- cgit v1.2.3 From 531ac136bf4ed333cb906ac229d986605a8207a6 Mon Sep 17 00:00:00 2001 From: Denny Date: Mon, 29 Oct 2012 14:53:47 -0700 Subject: BlockManager UI. 
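
Note on the dependency refactor in the commit above: subclasses now hand their parents to the RDD base class, either through the (SparkContext, dependencies) constructor or through the one-parent constructor that records a OneToOneDependency, and then reach the parent back through firstParent/parent(i) instead of keeping their own field. A minimal sketch of that pattern follows, using a hypothetical DoubledRDD that is not part of the patch:

    package spark.rdd

    import java.lang.ref.WeakReference
    import spark.{RDD, Split}

    // Illustrative only: doubles every element of an RDD[Int], following the
    // same shape as MappedRDD and FilteredRDD in the diffs above.
    private[spark] class DoubledRDD(prev: WeakReference[RDD[Int]])
      extends RDD[Int](prev.get) {   // one-parent constructor records OneToOneDependency(prev.get)

      // Reach the parent through the dependency list, not a captured field,
      // so the reference can later be dropped when the lineage is truncated.
      override def splits = firstParent[Int].splits

      override def compute(split: Split) =
        firstParent[Int].iterator(split).map(_ * 2)
    }

With import spark.SparkContext._ in scope, the rddToWeakRefRDD implicit added above lets a caller write new DoubledRDD(someIntRdd) and have the weak reference created automatically.
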
--- core/src/main/scala/spark/RDD.scala | 8 ++ core/src/main/scala/spark/SparkContext.scala | 10 ++ .../scala/spark/storage/BlockManagerMaster.scala | 33 ++++++- .../main/scala/spark/storage/BlockManagerUI.scala | 102 +++++++++++++++++++++ core/src/main/scala/spark/util/AkkaUtils.scala | 5 +- core/src/main/twirl/spark/common/layout.scala.html | 35 +++++++ .../twirl/spark/deploy/common/layout.scala.html | 35 ------- .../twirl/spark/deploy/master/index.scala.html | 2 +- .../spark/deploy/master/job_details.scala.html | 2 +- .../twirl/spark/deploy/worker/index.scala.html | 2 +- core/src/main/twirl/spark/storage/index.scala.html | 28 ++++++ core/src/main/twirl/spark/storage/rdd.scala.html | 65 +++++++++++++ .../main/twirl/spark/storage/rdd_row.scala.html | 18 ++++ .../main/twirl/spark/storage/rdd_table.scala.html | 18 ++++ 14 files changed, 318 insertions(+), 45 deletions(-) create mode 100644 core/src/main/scala/spark/storage/BlockManagerUI.scala create mode 100644 core/src/main/twirl/spark/common/layout.scala.html delete mode 100644 core/src/main/twirl/spark/deploy/common/layout.scala.html create mode 100644 core/src/main/twirl/spark/storage/index.scala.html create mode 100644 core/src/main/twirl/spark/storage/rdd.scala.html create mode 100644 core/src/main/twirl/spark/storage/rdd_row.scala.html create mode 100644 core/src/main/twirl/spark/storage/rdd_table.scala.html (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 338dff4061..dc757dc6aa 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -107,6 +107,12 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE + /* Assign a name to this RDD */ + def name(name: String) = { + sc.rddNames(this.id) = name + this + } + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. 
@@ -118,6 +124,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial "Cannot change storage level of an RDD after it was already assigned a level") } storageLevel = newLevel + // Register the RDD with the SparkContext + sc.persistentRdds(id) = this this } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d26cccbfe1..71c9dcd017 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -1,6 +1,7 @@ package spark import java.io._ +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicInteger import java.net.{URI, URLClassLoader} @@ -102,10 +103,19 @@ class SparkContext( isLocal) SparkEnv.set(env) + // Start the BlockManager UI + spark.storage.BlockManagerUI.start(SparkEnv.get.actorSystem, + SparkEnv.get.blockManager.master.masterActor, this) + // Used to store a URL for each static file/jar together with the file's local timestamp private[spark] val addedFiles = HashMap[String, Long]() private[spark] val addedJars = HashMap[String, Long]() + // Keeps track of all persisted RDDs + private[spark] val persistentRdds = new ConcurrentHashMap[Int, RDD[_]]() + // A HashMap for friendly RDD Names + private[spark] val rddNames = new ConcurrentHashMap[Int, String]() + // Add each JAR given through the constructor jars.foreach { addJar(_) } diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index ace27e758c..d12a16869a 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -3,7 +3,8 @@ package spark.storage import java.io._ import java.util.{HashMap => JHashMap} -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConverters._ +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.util.Random import akka.actor._ @@ -90,6 +91,15 @@ case object StopBlockManagerMaster extends ToBlockManagerMaster private[spark] case object GetMemoryStatus extends ToBlockManagerMaster +private[spark] +case class GetStorageStatus extends ToBlockManagerMaster + +private[spark] +case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) + +private[spark] +case class StorageStatus(maxMem: Long, remainingMem: Long, blocks: Map[String, BlockStatus]) + private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { @@ -99,7 +109,8 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor val maxMem: Long) { private var _lastSeenMs = timeMs private var _remainingMem = maxMem - private val _blocks = new JHashMap[String, StorageLevel] + + private val _blocks = new JHashMap[String, BlockStatus] logInfo("Registering block manager %s:%d with %s RAM".format( blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) @@ -115,7 +126,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor if (_blocks.containsKey(blockId)) { // The block exists on the slave already. 
- val originalLevel: StorageLevel = _blocks.get(blockId) + val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel if (originalLevel.useMemory) { _remainingMem += memSize @@ -124,7 +135,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor if (storageLevel.isValid) { // isValid means it is either stored in-memory or on-disk. - _blocks.put(blockId, storageLevel) + _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) if (storageLevel.useMemory) { _remainingMem -= memSize logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( @@ -137,7 +148,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor } } else if (_blocks.containsKey(blockId)) { // If isValid is not true, drop the block. - val originalLevel: StorageLevel = _blocks.get(blockId) + val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel _blocks.remove(blockId) if (originalLevel.useMemory) { _remainingMem += memSize @@ -152,6 +163,8 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor } } + def blocks: JHashMap[String, BlockStatus] = _blocks + def remainingMem: Long = _remainingMem def lastSeenMs: Long = _lastSeenMs @@ -198,6 +211,9 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor case GetMemoryStatus => getMemoryStatus + case GetStorageStatus => + getStorageStatus + case RemoveHost(host) => removeHost(host) sender ! true @@ -219,6 +235,13 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor sender ! res } + private def getStorageStatus() { + val res = blockManagerInfo.map { case(blockManagerId, info) => + StorageStatus(info.maxMem, info.remainingMem, info.blocks.asScala) + } + sender ! res + } + private def register(blockManagerId: BlockManagerId, maxMemSize: Long) { val startTimeMs = System.currentTimeMillis() val tmp = " " + blockManagerId + " " diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala new file mode 100644 index 0000000000..c168f60c35 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -0,0 +1,102 @@ +package spark.storage + +import akka.actor.{ActorRef, ActorSystem} +import akka.dispatch.Await +import akka.pattern.ask +import akka.util.Timeout +import akka.util.duration._ +import cc.spray.Directives +import cc.spray.directives._ +import cc.spray.typeconversion.TwirlSupport._ +import scala.collection.mutable.ArrayBuffer +import spark.{Logging, SparkContext, SparkEnv} +import spark.util.AkkaUtils + +private[spark] +object BlockManagerUI extends Logging { + + /* Starts the Web interface for the BlockManager */ + def start(actorSystem : ActorSystem, masterActor: ActorRef, sc: SparkContext) { + val webUIDirectives = new BlockManagerUIDirectives(actorSystem, masterActor, sc) + try { + logInfo("Starting BlockManager WebUI.") + val port = Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt + AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, webUIDirectives.handler, "BlockManagerHTTPServer") + } catch { + case e: Exception => + logError("Failed to create BlockManager WebUI", e) + System.exit(1) + } + } + +} + +private[spark] +case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, numPartitions: Int, memSize: Long, diskSize: Long) + +private[spark] +class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, sc: SparkContext) extends Directives { + + val 
STATIC_RESOURCE_DIR = "spark/deploy/static" + implicit val timeout = Timeout(1 seconds) + + val handler = { + + get { path("") { completeWith { + // Request the current storage status from the Master + val future = master ? GetStorageStatus + future.map { status => + val storageStati = status.asInstanceOf[ArrayBuffer[StorageStatus]] + + // Calculate macro-level statistics + val maxMem = storageStati.map(_.maxMem).reduce(_+_) + val remainingMem = storageStati.map(_.remainingMem).reduce(_+_) + val diskSpaceUsed = storageStati.flatMap(_.blocks.values.map(_.diskSize)) + .reduceOption(_+_).getOrElse(0L) + + // Filter out everything that's not and rdd. + val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => k.startsWith("rdd") }.toMap + val rdds = rddInfoFromBlockStati(rddBlocks) + + spark.storage.html.index.render(maxMem, remainingMem, diskSpaceUsed, rdds.toList) + } + }}} ~ + get { path("rdd") { parameter("id") { id => { completeWith { + val future = master ? GetStorageStatus + future.map { status => + val prefix = "rdd_" + id.toString + + val storageStati = status.asInstanceOf[ArrayBuffer[StorageStatus]] + val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => k.startsWith(prefix) }.toMap + val rddInfo = rddInfoFromBlockStati(rddBlocks).first + + spark.storage.html.rdd.render(rddInfo, rddBlocks) + + } + }}}}} ~ + pathPrefix("static") { + getFromResourceDirectory(STATIC_RESOURCE_DIR) + } + + } + + private def rddInfoFromBlockStati(infos: Map[String, BlockStatus]) : Array[RDDInfo] = { + infos.groupBy { case(k,v) => + // Group by rdd name, ignore the partition name + k.substring(0,k.lastIndexOf('_')) + }.map { case(k,v) => + val blockStati = v.map(_._2).toArray + // Add up memory and disk sizes + val tmp = blockStati.map { x => (x.memSize, x.diskSize)}.reduce { (x,y) => + (x._1 + y._1, x._2 + y._2) + } + // Get the friendly name for the rdd, if available. + // This is pretty hacky, is there a better way? + val rddId = k.split("_").last.toInt + val rddName : String = Option(sc.rddNames.get(rddId)).getOrElse(k) + val rddStorageLevel = sc.persistentRdds.get(rddId).getStorageLevel + RDDInfo(rddId, rddName, rddStorageLevel, blockStati.length, tmp._1, tmp._2) + }.toArray + } + +} diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index b466b5239c..13bc0f8ccc 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -50,12 +50,13 @@ private[spark] object AkkaUtils { * Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to * handle requests. Throws a SparkException if this fails. 
*/ - def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route) { + def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, + name: String = "HttpServer") { val ioWorker = new IoWorker(actorSystem).start() val httpService = actorSystem.actorOf(Props(new HttpService(route))) val rootService = actorSystem.actorOf(Props(new SprayCanRootService(httpService))) val server = actorSystem.actorOf( - Props(new HttpServer(ioWorker, SingletonHandler(rootService))), name = "HttpServer") + Props(new HttpServer(ioWorker, SingletonHandler(rootService))), name = name) actorSystem.registerOnTermination { ioWorker.stop() } val timeout = 3.seconds val future = server.ask(HttpServer.Bind(ip, port))(timeout) diff --git a/core/src/main/twirl/spark/common/layout.scala.html b/core/src/main/twirl/spark/common/layout.scala.html new file mode 100644 index 0000000000..b9192060aa --- /dev/null +++ b/core/src/main/twirl/spark/common/layout.scala.html @@ -0,0 +1,35 @@ +@(title: String)(content: Html) + + + + + + + + + + @title + + + + +
+     @title
+     @content
+ + + \ No newline at end of file diff --git a/core/src/main/twirl/spark/deploy/common/layout.scala.html b/core/src/main/twirl/spark/deploy/common/layout.scala.html deleted file mode 100644 index b9192060aa..0000000000 --- a/core/src/main/twirl/spark/deploy/common/layout.scala.html +++ /dev/null @@ -1,35 +0,0 @@ -@(title: String)(content: Html) - - - - - - - - - - @title - - - - -
-     @title
-     @content
- - - \ No newline at end of file diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html index 7562076b00..2e15fe2200 100644 --- a/core/src/main/twirl/spark/deploy/master/index.scala.html +++ b/core/src/main/twirl/spark/deploy/master/index.scala.html @@ -1,7 +1,7 @@ @(state: spark.deploy.MasterState) @import spark.deploy.master._ -@spark.deploy.common.html.layout(title = "Spark Master on " + state.uri) { +@spark.common.html.layout(title = "Spark Master on " + state.uri) {
diff --git a/core/src/main/twirl/spark/deploy/master/job_details.scala.html b/core/src/main/twirl/spark/deploy/master/job_details.scala.html index dcf41c28f2..d02a51b214 100644 --- a/core/src/main/twirl/spark/deploy/master/job_details.scala.html +++ b/core/src/main/twirl/spark/deploy/master/job_details.scala.html @@ -1,6 +1,6 @@ @(job: spark.deploy.master.JobInfo) -@spark.deploy.common.html.layout(title = "Job Details") { +@spark.common.html.layout(title = "Job Details") {
diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html index 69746ed02c..40c2d81d77 100644 --- a/core/src/main/twirl/spark/deploy/worker/index.scala.html +++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html @@ -1,6 +1,6 @@ @(worker: spark.deploy.WorkerState) -@spark.deploy.common.html.layout(title = "Spark Worker on " + worker.uri) { +@spark.common.html.layout(title = "Spark Worker on " + worker.uri) {
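
The storage templates added next render the per-RDD information that the RDD.scala and SparkContext.scala changes earlier in this commit begin collecting (persistentRdds and rddNames). A hedged driver-side sketch of how an RDD ends up on the dashboard; the input path and the name are placeholders, and sc is assumed to be an existing SparkContext:

    // Name and persist the RDD so it is recorded in sc.rddNames and
    // sc.persistentRdds, which are what the storage pages read.
    val lines = sc.textFile("hdfs://...")     // placeholder input path
    lines.name("raw-lines")                   // friendly name shown in the RDD table
    lines.cache()                             // assigns a storage level and registers the RDD

    lines.count()                             // materialize some blocks so sizes show up

    // The dashboard is served by BlockManagerUI.start(), bound to port 9080
    // unless BLOCKMANAGER_UI_PORT is set.
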
diff --git a/core/src/main/twirl/spark/storage/index.scala.html b/core/src/main/twirl/spark/storage/index.scala.html new file mode 100644 index 0000000000..fa7dad51ee --- /dev/null +++ b/core/src/main/twirl/spark/storage/index.scala.html @@ -0,0 +1,28 @@ +@(maxMem: Long, remainingMem: Long, diskSpaceUsed: Long, rdds: List[spark.storage.RDDInfo]) + +@spark.common.html.layout(title = "Storage Dashboard") { + + +
+     • Memory: @{spark.Utils.memoryBytesToString(maxMem - remainingMem)} Used
+       (@{spark.Utils.memoryBytesToString(remainingMem)} Available)
+     • Disk: @{spark.Utils.memoryBytesToString(diskSpaceUsed)} Used
+     RDD Summary
+     @rdd_table(rdds)
+ +} \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd.scala.html b/core/src/main/twirl/spark/storage/rdd.scala.html new file mode 100644 index 0000000000..3a70326efe --- /dev/null +++ b/core/src/main/twirl/spark/storage/rdd.scala.html @@ -0,0 +1,65 @@ +@(rddInfo: spark.storage.RDDInfo, blocks: Map[String, spark.storage.BlockStatus]) + +@spark.common.html.layout(title = "RDD Info ") { + + +
+     • Storage Level:
+         @(if (rddInfo.storageLevel.useDisk) "Disk" else "")
+         @(if (rddInfo.storageLevel.useMemory) "Memory" else "")
+         @(if (rddInfo.storageLevel.deserialized) "Deserialized" else "")
+         @(rddInfo.storageLevel.replication)x Replicated
+     • Partitions: @(rddInfo.numPartitions)
+     • Memory Size: @{spark.Utils.memoryBytesToString(rddInfo.memSize)}
+     • Disk Size: @{spark.Utils.memoryBytesToString(rddInfo.diskSize)}
+     RDD Summary
+     Block Name | Storage Level | Size in Memory | Size on Disk
+     @blocks.map { case (k,v) =>
+       @k | @v.storageLevel | @{spark.Utils.memoryBytesToString(v.memSize)} | @{spark.Utils.memoryBytesToString(v.diskSize)}
+     }
+ +} \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd_row.scala.html b/core/src/main/twirl/spark/storage/rdd_row.scala.html new file mode 100644 index 0000000000..3dd9944e3b --- /dev/null +++ b/core/src/main/twirl/spark/storage/rdd_row.scala.html @@ -0,0 +1,18 @@ +@(rdd: spark.storage.RDDInfo) + + + + + @rdd.name + + + + @(if (rdd.storageLevel.useDisk) "Disk" else "") + @(if (rdd.storageLevel.useMemory) "Memory" else "") + @(if (rdd.storageLevel.deserialized) "Deserialized" else "") + @(rdd.storageLevel.replication)x Replicated + + @rdd.numPartitions + @{spark.Utils.memoryBytesToString(rdd.memSize)} + @{spark.Utils.memoryBytesToString(rdd.diskSize)} + \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd_table.scala.html b/core/src/main/twirl/spark/storage/rdd_table.scala.html new file mode 100644 index 0000000000..24f55ccefb --- /dev/null +++ b/core/src/main/twirl/spark/storage/rdd_table.scala.html @@ -0,0 +1,18 @@ +@(rdds: List[spark.storage.RDDInfo]) + + + + + + + + + + + + + @for(rdd <- rdds) { + @rdd_row(rdd) + } + +
RDD Name | Storage Level | Partitions | Size in Memory | Size on Disk
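
The rdd and rdd_table pages above are populated by rddInfoFromBlockStati in BlockManagerUIDirectives, which folds block-level statistics back into one row per RDD. A simplified, self-contained sketch of that grouping logic; the types and names here are illustrative rather than the patch's exact classes:

    // Block ids look like "rdd_<rddId>_<splitIndex>". Strip the split index,
    // group by the remaining "rdd_<rddId>" prefix, and sum the sizes.
    case class BlockStats(memSize: Long, diskSize: Long)

    def summarize(blocks: Map[String, BlockStats]): Map[Int, (Int, Long, Long)] =
      blocks
        .groupBy { case (blockId, _) => blockId.substring(0, blockId.lastIndexOf('_')) }
        .map { case (prefix, group) =>
          val rddId = prefix.split("_").last.toInt    // "rdd_3" -> 3
          val mem   = group.values.map(_.memSize).sum
          val disk  = group.values.map(_.diskSize).sum
          rddId -> (group.size, mem, disk)            // (cached partitions, memory bytes, disk bytes)
        }
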
\ No newline at end of file -- cgit v1.2.3 From eb95212f4d24dbcd734922f39d51e6fdeaeb4c8b Mon Sep 17 00:00:00 2001 From: Denny Date: Mon, 29 Oct 2012 14:57:32 -0700 Subject: code Formatting --- core/src/main/scala/spark/storage/BlockManagerUI.scala | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index c168f60c35..635c096c87 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -21,7 +21,8 @@ object BlockManagerUI extends Logging { try { logInfo("Starting BlockManager WebUI.") val port = Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt - AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, webUIDirectives.handler, "BlockManagerHTTPServer") + AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, + webUIDirectives.handler, "BlockManagerHTTPServer") } catch { case e: Exception => logError("Failed to create BlockManager WebUI", e) @@ -32,10 +33,12 @@ object BlockManagerUI extends Logging { } private[spark] -case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, numPartitions: Int, memSize: Long, diskSize: Long) +case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, + numPartitions: Int, memSize: Long, diskSize: Long) private[spark] -class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, sc: SparkContext) extends Directives { +class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, + sc: SparkContext) extends Directives { val STATIC_RESOURCE_DIR = "spark/deploy/static" implicit val timeout = Timeout(1 seconds) @@ -55,7 +58,9 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, s .reduceOption(_+_).getOrElse(0L) // Filter out everything that's not and rdd. 
- val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => k.startsWith("rdd") }.toMap + val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => + k.startsWith("rdd") + }.toMap val rdds = rddInfoFromBlockStati(rddBlocks) spark.storage.html.index.render(maxMem, remainingMem, diskSpaceUsed, rdds.toList) @@ -67,7 +72,9 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, s val prefix = "rdd_" + id.toString val storageStati = status.asInstanceOf[ArrayBuffer[StorageStatus]] - val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => k.startsWith(prefix) }.toMap + val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => + k.startsWith(prefix) + }.toMap val rddInfo = rddInfoFromBlockStati(rddBlocks).first spark.storage.html.rdd.render(rddInfo, rddBlocks) -- cgit v1.2.3 From ceec1a1a6abb1fd03316e7fcc532d7e121d5bf65 Mon Sep 17 00:00:00 2001 From: Denny Date: Mon, 29 Oct 2012 15:03:01 -0700 Subject: Nicer storage level format on RDD page --- core/src/main/twirl/spark/storage/rdd.scala.html | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/twirl/spark/storage/rdd.scala.html b/core/src/main/twirl/spark/storage/rdd.scala.html index 3a70326efe..075289c826 100644 --- a/core/src/main/twirl/spark/storage/rdd.scala.html +++ b/core/src/main/twirl/spark/storage/rdd.scala.html @@ -50,7 +50,12 @@ @blocks.map { case (k,v) => @k - @v.storageLevel + + @(if (v.storageLevel.useDisk) "Disk" else "") + @(if (v.storageLevel.useMemory) "Memory" else "") + @(if (v.storageLevel.deserialized) "Deserialized" else "") + @(v.storageLevel.replication)x Replicated + @{spark.Utils.memoryBytesToString(v.memSize)} @{spark.Utils.memoryBytesToString(v.diskSize)} -- cgit v1.2.3 From 0dcd770fdc4d558972b635b6770ed0120280ef22 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 30 Oct 2012 16:09:37 -0700 Subject: Added checkpointing support to all RDDs, along with CheckpointSuite to test checkpointing in them. 
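
The diffs below implement the checkpointing summarized in this commit message: checkpoint() only marks an RDD, the actual save happens in doCheckpoint() after a job finishes, and the RDD's dependencies are then rewired to the saved object file. A hedged sketch of the intended flow, written as if from code inside the spark package since checkpoint() is protected[spark] in this patch; paths are placeholders:

    import spark.SparkContext._   // for the pair-RDD implicits (reduceByKey)

    sc.setCheckpointDir("hdfs://namenode:9000/checkpoints")   // must not already exist

    val counts = sc.textFile("hdfs://...")                    // placeholder input
      .flatMap(_.split(" "))
      .map(w => (w, 1))
      .reduceByKey(_ + _)

    counts.cache()        // recommended, so writing the checkpoint does not recompute the RDD
    counts.checkpoint()   // only sets shouldCheckpoint; nothing is written yet

    counts.count()        // after runJob() returns, doCheckpoint() saves the RDD as an
                          // object file in a background thread and swaps its dependencies
                          // for a OneToOneDependency on the reloaded copy
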
--- core/src/main/scala/spark/PairRDDFunctions.scala | 4 +- core/src/main/scala/spark/ParallelCollection.scala | 4 +- core/src/main/scala/spark/RDD.scala | 129 +++++++++++++++++---- core/src/main/scala/spark/SparkContext.scala | 21 ++++ core/src/main/scala/spark/rdd/BlockRDD.scala | 13 ++- core/src/main/scala/spark/rdd/CartesianRDD.scala | 38 +++--- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 19 +-- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 26 +++-- core/src/main/scala/spark/rdd/FilteredRDD.scala | 2 +- core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 2 +- core/src/main/scala/spark/rdd/GlommedRDD.scala | 2 +- core/src/main/scala/spark/rdd/HadoopRDD.scala | 2 + .../main/scala/spark/rdd/MapPartitionsRDD.scala | 2 +- .../spark/rdd/MapPartitionsWithSplitRDD.scala | 2 +- core/src/main/scala/spark/rdd/MappedRDD.scala | 2 +- core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 2 + core/src/main/scala/spark/rdd/PipedRDD.scala | 9 +- core/src/main/scala/spark/rdd/SampledRDD.scala | 2 +- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 5 - core/src/main/scala/spark/rdd/UnionRDD.scala | 32 ++--- core/src/test/scala/spark/CheckpointSuite.scala | 116 ++++++++++++++++++ core/src/test/scala/spark/RDDSuite.scala | 25 +++- 22 files changed, 352 insertions(+), 107 deletions(-) create mode 100644 core/src/test/scala/spark/CheckpointSuite.scala (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index f52af08125..1f82bd3ab8 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -625,7 +625,7 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( } private[spark] -class MappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V => U) +class MappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => U) extends RDD[(K, U)](prev.get) { override def splits = firstParent[(K, V)].splits @@ -634,7 +634,7 @@ class MappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V } private[spark] -class FlatMappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U]) +class FlatMappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U]) extends RDD[(K, U)](prev.get) { override def splits = firstParent[(K, V)].splits diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index ad06ee9736..9725017b61 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -22,10 +22,10 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( } private[spark] class ParallelCollection[T: ClassManifest]( - @transient sc_ : SparkContext, + @transient sc : SparkContext, @transient data: Seq[T], numSlices: Int) - extends RDD[T](sc_, Nil) { + extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split // instead. UPDATE: With the new changes to enable checkpointing, this an be done. 
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index c9f3763f73..e272a0ede9 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -13,6 +13,7 @@ import scala.collection.Map import scala.collection.mutable.HashMap import scala.collection.JavaConversions.mapAsScalaMap +import org.apache.hadoop.fs.Path import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text @@ -74,7 +75,7 @@ import SparkContext._ */ abstract class RDD[T: ClassManifest]( @transient var sc: SparkContext, - @transient var dependencies_ : List[Dependency[_]] = Nil + var dependencies_ : List[Dependency[_]] ) extends Serializable { @@ -91,7 +92,6 @@ abstract class RDD[T: ClassManifest]( /** How this RDD depends on any parent RDDs. */ def dependencies: List[Dependency[_]] = dependencies_ - //var dependencies: List[Dependency[_]] = dependencies_ /** Record user function generating this RDD. */ private[spark] val origin = Utils.getSparkCallSite @@ -100,7 +100,13 @@ abstract class RDD[T: ClassManifest]( val partitioner: Option[Partitioner] = None /** Optionally overridden by subclasses to specify placement preferences. */ - def preferredLocations(split: Split): Seq[String] = Nil + def preferredLocations(split: Split): Seq[String] = { + if (isCheckpointed) { + checkpointRDD.preferredLocations(split) + } else { + Nil + } + } /** The [[spark.SparkContext]] that this RDD was created on. */ def context = sc @@ -113,8 +119,23 @@ abstract class RDD[T: ClassManifest]( // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE - private[spark] def firstParent[U: ClassManifest] = dependencies.head.rdd.asInstanceOf[RDD[U]] - private[spark] def parent[U: ClassManifest](id: Int) = dependencies(id).rdd.asInstanceOf[RDD[U]] + /** Returns the first parent RDD */ + private[spark] def firstParent[U: ClassManifest] = { + dependencies.head.rdd.asInstanceOf[RDD[U]] + } + + /** Returns the `i` th parent RDD */ + private[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] + + // Variables relating to checkpointing + val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD + var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing + var isCheckpointInProgress = false // set to true when checkpointing is in progress + var isCheckpointed = false // set to true after checkpointing is completed + + var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed + var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file + var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD // Methods available on all RDDs: @@ -141,32 +162,94 @@ abstract class RDD[T: ClassManifest]( /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = { - if (!level.useDisk && level.replication < 2) { - throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") - } - - // This is a hack. Ideally this should re-use the code used by the CacheTracker - // to generate the key. 
- def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index) - - persist(level) - sc.runJob(this, (iter: Iterator[T]) => {} ) - - val p = this.partitioner - - new BlockRDD[T](sc, splits.map(getSplitKey).toArray) { - override val partitioner = p + /** + * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` + * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. + * This is used to truncate very long lineages. In the current implementation, Spark will save + * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. + * Hence, it is strongly recommended to use checkpoint() on RDDs when + * (i) Checkpoint() is called before the any job has been executed on this RDD. + * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will + * require recomputation. + */ + protected[spark] def checkpoint() { + synchronized { + if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) { + // do nothing + } else if (isCheckpointable) { + shouldCheckpoint = true + } else { + throw new Exception(this + " cannot be checkpointed") + } } } - + + /** + * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler after a job + * using this RDD has completed (therefore the RDD has been materialized and + * potentially stored in memory). In case this RDD is not marked for checkpointing, + * doCheckpoint() is called recursively on the parent RDDs. + */ + private[spark] def doCheckpoint() { + val startCheckpoint = synchronized { + if (isCheckpointable && shouldCheckpoint && !isCheckpointInProgress) { + isCheckpointInProgress = true + true + } else { + false + } + } + + if (startCheckpoint) { + val rdd = this + val env = SparkEnv.get + + // Spawn a new thread to do the checkpoint as it takes sometime to write the RDD to file + val th = new Thread() { + override def run() { + // Save the RDD to a file, create a new HadoopRDD from it, + // and change the dependencies from the original parents to the new RDD + SparkEnv.set(env) + rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString + rdd.saveAsObjectFile(checkpointFile) + rdd.synchronized { + rdd.checkpointRDD = context.objectFile[T](checkpointFile) + rdd.checkpointRDDSplits = rdd.checkpointRDD.splits + rdd.changeDependencies(rdd.checkpointRDD) + rdd.shouldCheckpoint = false + rdd.isCheckpointInProgress = false + rdd.isCheckpointed = true + } + } + } + th.start() + } else { + // Recursively call doCheckpoint() to perform checkpointing on parent RDD if they are marked + dependencies.foreach(_.rdd.doCheckpoint()) + } + } + + /** + * Changes the dependencies of this RDD from its original parents to the new [[spark.rdd.HadoopRDD]] + * (`newRDD`) created from the checkpoint file. This method must ensure that all references + * to the original parent RDDs must be removed to enable the parent RDDs to be garbage + * collected. Subclasses of RDD may override this method for implementing their own changing + * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + */ + protected def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD)) + } + /** * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. * This should ''not'' be called by users directly, but is available for implementors of custom * subclasses of RDD. 
*/ final def iterator(split: Split): Iterator[T] = { - if (storageLevel != StorageLevel.NONE) { + if (isCheckpointed) { + // ASSUMPTION: Checkpoint Hadoop RDD will have same number of splits as original + checkpointRDD.iterator(checkpointRDDSplits(split.index)) + } else if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) } else { compute(split) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 6b957a6356..79ceab5f4f 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -188,6 +188,8 @@ class SparkContext( private var dagScheduler = new DAGScheduler(taskScheduler) + private[spark] var checkpointDir: String = null + // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. */ @@ -519,6 +521,7 @@ class SparkContext( val start = System.nanoTime val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") + rdd.doCheckpoint() result } @@ -575,6 +578,24 @@ class SparkContext( return f } + /** + * Set the directory under which RDDs are going to be checkpointed. This method will + * create this directory and will throw an exception of the path already exists (to avoid + * overwriting existing files may be overwritten). The directory will be deleted on exit + * if indicated. + */ + def setCheckpointDir(dir: String, deleteOnExit: Boolean = false) { + val path = new Path(dir) + val fs = path.getFileSystem(new Configuration()) + if (fs.exists(path)) { + throw new Exception("Checkpoint directory '" + path + "' already exists.") + } else { + fs.mkdirs(path) + if (deleteOnExit) fs.deleteOnExit(path) + } + checkpointDir = dir + } + /** Default level of parallelism to use when not given by user (e.g. 
for reduce tasks) */ def defaultParallelism: Int = taskScheduler.defaultParallelism diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index cb73976aed..f4c3f99011 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -14,7 +14,7 @@ private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split private[spark] class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) - extends RDD[T](sc) { + extends RDD[T](sc, Nil) { @transient val splits_ = (0 until blockIds.size).map(i => { @@ -41,9 +41,12 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St } } - override def preferredLocations(split: Split) = - locations_(split.asInstanceOf[BlockRDDSplit].blockId) - - override val dependencies: List[Dependency[_]] = Nil + override def preferredLocations(split: Split) = { + if (isCheckpointed) { + checkpointRDD.preferredLocations(split) + } else { + locations_(split.asInstanceOf[BlockRDDSplit].blockId) + } + } } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index c97b835630..458ad38d55 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,9 +1,6 @@ package spark.rdd -import spark.NarrowDependency -import spark.RDD -import spark.SparkContext -import spark.Split +import spark._ import java.lang.ref.WeakReference private[spark] @@ -14,19 +11,15 @@ class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with private[spark] class CartesianRDD[T: ClassManifest, U:ClassManifest]( sc: SparkContext, - rdd1_ : WeakReference[RDD[T]], - rdd2_ : WeakReference[RDD[U]]) - extends RDD[Pair[T, U]](sc) + var rdd1 : RDD[T], + var rdd2 : RDD[U]) + extends RDD[Pair[T, U]](sc, Nil) with Serializable { - def rdd1 = rdd1_.get - def rdd2 = rdd2_.get - val numSplitsInRdd2 = rdd2.splits.size - // TODO: make this null when finishing checkpoint @transient - val splits_ = { + var splits_ = { // create the cross product split val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) for (s1 <- rdd1.splits; s2 <- rdd2.splits) { @@ -36,12 +29,15 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( array } - // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override def preferredLocations(split: Split) = { - val currSplit = split.asInstanceOf[CartesianSplit] - rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) + if (isCheckpointed) { + checkpointRDD.preferredLocations(split) + } else { + val currSplit = split.asInstanceOf[CartesianSplit] + rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) + } } override def compute(split: Split) = { @@ -49,8 +45,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( for (x <- rdd1.iterator(currSplit.s1); y <- rdd2.iterator(currSplit.s2)) yield (x, y) } - // TODO: make this null when finishing checkpoint - var deps = List( + var deps_ = List( new NarrowDependency(rdd1) { def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2) }, @@ -59,5 +54,12 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( } ) - override def dependencies = deps + override def dependencies = deps_ + + override protected def changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + 
rdd1 = null + rdd2 = null + } } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index af54ac2fa0..a313ebcbe8 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -30,14 +30,13 @@ private[spark] class CoGroupAggregator { (b1, b2) => b1 ++ b2 }) with Serializable -class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) +class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging { val aggr = new CoGroupAggregator - // TODO: make this null when finishing checkpoint @transient - var deps = { + var deps_ = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) @@ -52,11 +51,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) deps.toList } - override def dependencies = deps + override def dependencies = deps_ - // TODO: make this null when finishing checkpoint @transient - val splits_ : Array[Split] = { + var splits_ : Array[Split] = { val firstRdd = rdds.head val array = new Array[Split](part.numPartitions) for (i <- 0 until array.size) { @@ -72,13 +70,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) array } - // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override val partitioner = Some(part) - override def preferredLocations(s: Split) = Nil - override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = { val split = s.asInstanceOf[CoGroupSplit] val numRdds = split.deps.size @@ -106,4 +101,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) } map.iterator } + + override protected def changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + rdds = null + } } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 573acf8893..5b5f72ddeb 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -1,8 +1,7 @@ package spark.rdd -import spark.NarrowDependency -import spark.RDD -import spark.Split +import spark._ +import java.lang.ref.WeakReference private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split @@ -15,13 +14,12 @@ private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) exten * or to avoid having a large number of small tasks when processing a directory with many files. 
*/ class CoalescedRDD[T: ClassManifest]( - @transient prev: RDD[T], // TODO: Make this a weak reference + var prev: RDD[T], maxPartitions: Int) extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs - // TODO: make this null when finishing checkpoint - @transient val splits_ : Array[Split] = { - val prevSplits = firstParent[T].splits + @transient var splits_ : Array[Split] = { + val prevSplits = prev.splits if (prevSplits.length < maxPartitions) { prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) } } else { @@ -33,7 +31,6 @@ class CoalescedRDD[T: ClassManifest]( } } - // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override def compute(split: Split): Iterator[T] = { @@ -42,13 +39,18 @@ class CoalescedRDD[T: ClassManifest]( } } - // TODO: make this null when finishing checkpoint - var deps = List( - new NarrowDependency(firstParent) { + var deps_ : List[Dependency[_]] = List( + new NarrowDependency(prev) { def getParents(id: Int): Seq[Int] = splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index) } ) - override def dependencies = deps + override def dependencies = deps_ + + override protected def changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD)) + splits_ = newRDD.splits + prev = null + } } diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index cc2a3acd3a..1370cf6faf 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -7,7 +7,7 @@ import java.lang.ref.WeakReference private[spark] class FilteredRDD[T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], f: T => Boolean) extends RDD[T](prev.get) { diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index 34bd784c13..6b2cc67568 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -7,7 +7,7 @@ import java.lang.ref.WeakReference private[spark] class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], f: T => TraversableOnce[U]) extends RDD[U](prev.get) { diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index 9321e89dcd..0f0b6ab0ff 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -6,7 +6,7 @@ import spark.Split import java.lang.ref.WeakReference private[spark] -class GlommedRDD[T: ClassManifest](@transient prev: WeakReference[RDD[T]]) +class GlommedRDD[T: ClassManifest](prev: WeakReference[RDD[T]]) extends RDD[Array[T]](prev.get) { override def splits = firstParent[T].splits override def compute(split: Split) = Array(firstParent[T].iterator(split).toArray).iterator diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index a12531ea89..19ed56d9c0 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -115,4 +115,6 @@ class HadoopRDD[K, V]( val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } + + override val isCheckpointable = false } diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala 
b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index bad872c430..b04f56cfcc 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -7,7 +7,7 @@ import java.lang.ref.WeakReference private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false) extends RDD[U](prev.get) { diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index d7b238b05d..7a4b6ffb03 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -12,7 +12,7 @@ import java.lang.ref.WeakReference */ private[spark] class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], f: (Int, Iterator[T]) => Iterator[U]) extends RDD[U](prev.get) { diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 126c6f332b..8fa1872e0a 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -7,7 +7,7 @@ import java.lang.ref.WeakReference private[spark] class MappedRDD[U: ClassManifest, T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], f: T => U) extends RDD[U](prev.get) { diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index c12df5839e..2875abb2db 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -93,4 +93,6 @@ class NewHadoopRDD[K, V]( val theSplit = split.asInstanceOf[NewHadoopSplit] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") } + + override val isCheckpointable = false } diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index d54579d6d1..d9293a9d1a 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -12,6 +12,7 @@ import spark.OneToOneDependency import spark.RDD import spark.SparkEnv import spark.Split +import java.lang.ref.WeakReference /** @@ -19,16 +20,16 @@ import spark.Split * (printing them one per line) and returns the output as a collection of strings. */ class PipedRDD[T: ClassManifest]( - @transient prev: RDD[T], + prev: WeakReference[RDD[T]], command: Seq[String], envVars: Map[String, String]) - extends RDD[String](prev) { + extends RDD[String](prev.get) { - def this(@transient prev: RDD[T], command: Seq[String]) = this(prev, command, Map()) + def this(prev: WeakReference[RDD[T]], command: Seq[String]) = this(prev, command, Map()) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. 
by spaces) - def this(@transient prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) + def this(prev: WeakReference[RDD[T]], command: String) = this(prev, PipedRDD.tokenize(command)) override def splits = firstParent[T].splits diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index 00b521b130..f273f257f8 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -15,7 +15,7 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali } class SampledRDD[T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], withReplacement: Boolean, frac: Double, seed: Int) diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 62867dab4f..b7d843c26d 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -31,11 +31,6 @@ class ShuffledRDD[K, V]( override def splits = splits_ - override def preferredLocations(split: Split) = Nil - - //val dep = new ShuffleDependency(parent, part) - //override val dependencies = List(dep) - override def compute(split: Split): Iterator[(K, V)] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 0a61a2d1f5..643a174160 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -2,11 +2,7 @@ package spark.rdd import scala.collection.mutable.ArrayBuffer -import spark.Dependency -import spark.RangeDependency -import spark.RDD -import spark.SparkContext -import spark.Split +import spark._ import java.lang.ref.WeakReference private[spark] class UnionSplit[T: ClassManifest]( @@ -23,12 +19,11 @@ private[spark] class UnionSplit[T: ClassManifest]( class UnionRDD[T: ClassManifest]( sc: SparkContext, - @transient rdds: Seq[RDD[T]]) // TODO: Make this a weak reference + @transient var rdds: Seq[RDD[T]]) extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs - // TODO: make this null when finishing checkpoint @transient - val splits_ : Array[Split] = { + var splits_ : Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) var pos = 0 for (rdd <- rdds; split <- rdd.splits) { @@ -38,11 +33,9 @@ class UnionRDD[T: ClassManifest]( array } - // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ - // TODO: make this null when finishing checkpoint - @transient var deps = { + @transient var deps_ = { val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { @@ -52,10 +45,21 @@ class UnionRDD[T: ClassManifest]( deps.toList } - override def dependencies = deps + override def dependencies = deps_ override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() - override def preferredLocations(s: Split): Seq[String] = - s.asInstanceOf[UnionSplit[T]].preferredLocations() + override def preferredLocations(s: Split): Seq[String] = { + if (isCheckpointed) { + checkpointRDD.preferredLocations(s) + } else { + s.asInstanceOf[UnionSplit[T]].preferredLocations() + } + } + + override protected def changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD)) + splits_ = 
newRDD.splits + rdds = null + } } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala new file mode 100644 index 0000000000..0e5ca7dc21 --- /dev/null +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -0,0 +1,116 @@ +package spark + +import org.scalatest.{BeforeAndAfter, FunSuite} +import java.io.File +import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD} +import spark.SparkContext._ +import storage.StorageLevel + +class CheckpointSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + var checkpointDir: File = _ + + before { + checkpointDir = File.createTempFile("temp", "") + checkpointDir.delete() + + sc = new SparkContext("local", "test") + sc.setCheckpointDir(checkpointDir.toString) + } + + after { + if (sc != null) { + sc.stop() + sc = null + } + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") + + if (checkpointDir != null) { + checkpointDir.delete() + } + } + + test("ParallelCollection") { + val parCollection = sc.makeRDD(1 to 4) + parCollection.checkpoint() + assert(parCollection.dependencies === Nil) + val result = parCollection.collect() + sleep(parCollection) // slightly extra time as loading classes for the first can take some time + assert(sc.objectFile[Int](parCollection.checkpointFile).collect() === result) + assert(parCollection.dependencies != Nil) + assert(parCollection.collect() === result) + } + + test("BlockRDD") { + val blockId = "id" + val blockManager = SparkEnv.get.blockManager + blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) + val blockRDD = new BlockRDD[String](sc, Array(blockId)) + blockRDD.checkpoint() + val result = blockRDD.collect() + sleep(blockRDD) + assert(sc.objectFile[String](blockRDD.checkpointFile).collect() === result) + assert(blockRDD.dependencies != Nil) + assert(blockRDD.collect() === result) + } + + test("RDDs with one-to-one dependencies") { + testCheckpointing(_.map(x => x.toString)) + testCheckpointing(_.flatMap(x => 1 to x)) + testCheckpointing(_.filter(_ % 2 == 0)) + testCheckpointing(_.sample(false, 0.5, 0)) + testCheckpointing(_.glom()) + testCheckpointing(_.mapPartitions(_.map(_.toString))) + testCheckpointing(r => new MapPartitionsWithSplitRDD(r, + (i: Int, iter: Iterator[Int]) => iter.map(_.toString) )) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), 1000) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x), 1000) + testCheckpointing(_.pipe(Seq("cat"))) + } + + test("ShuffledRDD") { + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _)) + } + + test("UnionRDD") { + testCheckpointing(_.union(sc.makeRDD(5 to 6, 4))) + } + + test("CartesianRDD") { + testCheckpointing(_.cartesian(sc.makeRDD(5 to 6, 4)), 1000) + } + + test("CoalescedRDD") { + testCheckpointing(new CoalescedRDD(_, 2)) + } + + test("CoGroupedRDD") { + val rdd2 = sc.makeRDD(5 to 6, 4).map(x => (x % 2, 1)) + testCheckpointing(rdd1 => rdd1.map(x => (x % 2, 1)).cogroup(rdd2)) + testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2)) + } + + def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { + val parCollection = sc.makeRDD(1 to 4, 4) + val operatedRDD = op(parCollection) + operatedRDD.checkpoint() + val parentRDD = operatedRDD.dependencies.head.rdd + val result = operatedRDD.collect() + sleep(operatedRDD) + //println(parentRDD + ", " + 
operatedRDD.dependencies.head.rdd ) + assert(sc.objectFile[U](operatedRDD.checkpointFile).collect() === result) + assert(operatedRDD.dependencies.head.rdd != parentRDD) + assert(operatedRDD.collect() === result) + } + + def sleep(rdd: RDD[_]) { + val startTime = System.currentTimeMillis() + val maxWaitTime = 5000 + while(rdd.isCheckpointed == false && System.currentTimeMillis() < startTime + maxWaitTime) { + Thread.sleep(50) + } + assert(rdd.isCheckpointed === true, "Waiting for checkpoint to complete took more than " + maxWaitTime + " ms") + } +} diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 37a0ff0947..8ac7c8451a 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -19,7 +19,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.master.port") } - + test("basic operations") { sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) @@ -70,10 +70,23 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) } - test("checkpointing") { + test("basic checkpointing") { + import java.io.File + val checkpointDir = File.createTempFile("temp", "") + checkpointDir.delete() + sc = new SparkContext("local", "test") - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).flatMap(x => 1 to x).checkpoint() - assert(rdd.collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) + sc.setCheckpointDir(checkpointDir.toString) + val parCollection = sc.makeRDD(1 to 4) + val flatMappedRDD = parCollection.flatMap(x => 1 to x) + flatMappedRDD.checkpoint() + assert(flatMappedRDD.dependencies.head.rdd == parCollection) + val result = flatMappedRDD.collect() + Thread.sleep(1000) + assert(flatMappedRDD.dependencies.head.rdd != parCollection) + assert(flatMappedRDD.collect() === result) + + checkpointDir.deleteOnExit() } test("basic caching") { @@ -94,8 +107,8 @@ class RDDSuite extends FunSuite with BeforeAndAfter { List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) // Check that the narrow dependency is also specified correctly - assert(coalesced1.dependencies.head.getParents(0).toList === List(0, 1, 2, 3, 4)) - assert(coalesced1.dependencies.head.getParents(1).toList === List(5, 6, 7, 8, 9)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === List(0, 1, 2, 3, 4)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9)) val coalesced2 = new CoalescedRDD(data, 3) assert(coalesced2.collect().toList === (1 to 10).toList) -- cgit v1.2.3 From 34e569f40e184a6a4f21e9d79b0e8979d8f9541f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 31 Oct 2012 00:56:40 -0700 Subject: Added 'synchronized' to RDD serialization to ensure checkpoint-related changes are reflected atomically in the task closure. Added tests to ensure that jobs running on an RDD on which checkpointing is in progress do not hurt the result of the job.
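
The pattern the diff below introduces can be summarized by the following standalone sketch (the class and field names here are illustrative only, not part of the patch): the thread that finalizes a checkpoint and the thread that serializes the RDD into a task closure take the same lock, so a closure never captures a half-updated RDD.

    import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

    // Sketch: guard the Java serialization hooks with the object's own monitor so a
    // concurrent checkpoint cannot swap dependencies/splits mid-serialization.
    class CheckpointedValue extends Serializable {
      private var dependencies: List[String] = List("parent")

      // Called by the checkpointing thread once the checkpoint file has been written.
      def markCheckpointed(file: String) {
        synchronized {
          dependencies = List(file) // atomic with respect to writeObject below
        }
      }

      @throws(classOf[IOException])
      private def writeObject(oos: ObjectOutputStream) {
        synchronized {
          oos.defaultWriteObject() // task-closure serialization sees a consistent state
        }
      }

      @throws(classOf[IOException])
      private def readObject(ois: ObjectInputStream) {
        synchronized {
          ois.defaultReadObject()
        }
      }
    }
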
--- core/src/main/scala/spark/RDD.scala | 18 ++++++- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 7 ++- core/src/test/scala/spark/CheckpointSuite.scala | 71 ++++++++++++++++++++++++- 3 files changed, 92 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e272a0ede9..7b59a6f09e 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -1,8 +1,7 @@ package spark -import java.io.EOFException +import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream} import java.net.URL -import java.io.ObjectInputStream import java.util.concurrent.atomic.AtomicLong import java.util.Random import java.util.Date @@ -589,4 +588,19 @@ abstract class RDD[T: ClassManifest]( private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) } + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + synchronized { + oos.defaultWriteObject() + } + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + synchronized { + ois.defaultReadObject() + } + } + } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index b7d843c26d..31774585f4 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -27,7 +27,7 @@ class ShuffledRDD[K, V]( override val partitioner = Some(part) @transient - val splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) + var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) override def splits = splits_ @@ -35,4 +35,9 @@ class ShuffledRDD[K, V]( val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) } + + override def changeDependencies(newRDD: RDD[_]) { + dependencies_ = Nil + splits_ = newRDD.splits + } } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 0e5ca7dc21..57dc43ddac 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -5,8 +5,10 @@ import java.io.File import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD} import spark.SparkContext._ import storage.StorageLevel +import java.util.concurrent.Semaphore -class CheckpointSuite extends FunSuite with BeforeAndAfter { +class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { + initLogging() var sc: SparkContext = _ var checkpointDir: File = _ @@ -92,6 +94,35 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter { testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2)) } + /** + * This test forces two ResultTasks of the same job to be launched before and after + * the checkpointing of job's RDD is completed. + */ + test("Threading - ResultTasks") { + val op1 = (parCollection: RDD[Int]) => { + parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) }) + } + val op2 = (firstRDD: RDD[(Int, Int)]) => { + firstRDD.map(x => { println("2nd map running on " + x); Thread.sleep(500); x }) + } + testThreading(op1, op2) + } + + /** + * This test forces two ShuffleMapTasks of the same job to be launched before and after + * the checkpointing of job's RDD is completed. 
+ */ + test("Threading - ShuffleMapTasks") { + val op1 = (parCollection: RDD[Int]) => { + parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) }) + } + val op2 = (firstRDD: RDD[(Int, Int)]) => { + firstRDD.groupByKey(2).map(x => { println("2nd map running on " + x); Thread.sleep(500); x }) + } + testThreading(op1, op2) + } + + def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { val parCollection = sc.makeRDD(1 to 4, 4) val operatedRDD = op(parCollection) @@ -105,6 +136,44 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter { assert(operatedRDD.collect() === result) } + def testThreading[U: ClassManifest, V: ClassManifest](op1: (RDD[Int]) => RDD[U], op2: (RDD[U]) => RDD[V]) { + + val parCollection = sc.makeRDD(1 to 2, 2) + + // This is the RDD that is to be checkpointed + val firstRDD = op1(parCollection) + val parentRDD = firstRDD.dependencies.head.rdd + firstRDD.checkpoint() + + // This the RDD that uses firstRDD. This is designed to launch a + // ShuffleMapTask that uses firstRDD. + val secondRDD = op2(firstRDD) + + // Starting first job, to initiate the checkpointing + logInfo("\nLaunching 1st job to initiate checkpointing\n") + firstRDD.collect() + + // Checkpointing has started but not completed yet + Thread.sleep(100) + assert(firstRDD.dependencies.head.rdd === parentRDD) + + // Starting second job; first task of this job will be + // launched _before_ firstRDD is marked as checkpointed + // and the second task will be launched _after_ firstRDD + // is marked as checkpointed + logInfo("\nLaunching 2nd job that is designed to launch tasks " + + "before and after checkpointing is complete\n") + val result = secondRDD.collect() + + // Check whether firstRDD has been successfully checkpointed + assert(firstRDD.dependencies.head.rdd != parentRDD) + + logInfo("\nRecomputing 2nd job to verify the results of the previous computation\n") + // Check whether the result in the previous job was correct or not + val correctResult = secondRDD.collect() + assert(result === correctResult) + } + def sleep(rdd: RDD[_]) { val startTime = System.currentTimeMillis() val maxWaitTime = 5000 -- cgit v1.2.3 From d1542387891018914fdd6b647f17f0b05acdd40e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 4 Nov 2012 12:12:06 -0800 Subject: Made checkpointing of dstream graph to work with checkpointing of RDDs. For streams requiring checkpointing of its RDD, the default checkpoint interval is set to 10 seconds. 
--- core/src/main/scala/spark/RDD.scala | 32 +++-- core/src/main/scala/spark/SparkContext.scala | 13 +- .../main/scala/spark/streaming/Checkpoint.scala | 83 ++++++----- .../src/main/scala/spark/streaming/DStream.scala | 156 ++++++++++++++++----- .../main/scala/spark/streaming/DStreamGraph.scala | 7 +- .../spark/streaming/ReducedWindowedDStream.scala | 36 +++-- .../main/scala/spark/streaming/StateDStream.scala | 45 +----- .../scala/spark/streaming/StreamingContext.scala | 38 +++-- .../src/main/scala/spark/streaming/Time.scala | 4 + .../examples/FileStreamWithCheckpoint.scala | 10 +- .../streaming/examples/TopKWordCountRaw.scala | 5 +- .../spark/streaming/examples/WordCount2.scala | 7 +- .../spark/streaming/examples/WordCountRaw.scala | 6 +- .../scala/spark/streaming/examples/WordMax2.scala | 10 +- .../scala/spark/streaming/CheckpointSuite.scala | 77 +++++++--- .../test/scala/spark/streaming/TestSuiteBase.scala | 37 +++-- .../spark/streaming/WindowOperationsSuite.scala | 4 +- 17 files changed, 367 insertions(+), 203 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 7b59a6f09e..63048d5df0 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -119,22 +119,23 @@ abstract class RDD[T: ClassManifest]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Returns the first parent RDD */ - private[spark] def firstParent[U: ClassManifest] = { + protected[spark] def firstParent[U: ClassManifest] = { dependencies.head.rdd.asInstanceOf[RDD[U]] } /** Returns the `i` th parent RDD */ - private[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] + protected[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] // Variables relating to checkpointing - val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD - var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing - var isCheckpointInProgress = false // set to true when checkpointing is in progress - var isCheckpointed = false // set to true after checkpointing is completed + protected val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD - var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed - var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file - var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD + protected var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing + protected var isCheckpointInProgress = false // set to true when checkpointing is in progress + protected[spark] var isCheckpointed = false // set to true after checkpointing is completed + + protected[spark] var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed + protected var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file + protected var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD // Methods available on all RDDs: @@ -176,6 +177,9 @@ abstract class RDD[T: ClassManifest]( if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) { // do nothing } else if (isCheckpointable) { + if (sc.checkpointDir == null) { + throw new Exception("Checkpoint directory has not been set in the SparkContext.") + } shouldCheckpoint = true } else { throw new Exception(this + " cannot be 
checkpointed") @@ -183,6 +187,16 @@ abstract class RDD[T: ClassManifest]( } } + def getCheckpointData(): Any = { + synchronized { + if (isCheckpointed) { + checkpointFile + } else { + null + } + } + } + /** * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler after a job * using this RDD has completed (therefore the RDD has been materialized and diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 79ceab5f4f..d7326971a9 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -584,14 +584,15 @@ class SparkContext( * overwriting existing files may be overwritten). The directory will be deleted on exit * if indicated. */ - def setCheckpointDir(dir: String, deleteOnExit: Boolean = false) { + def setCheckpointDir(dir: String, useExisting: Boolean = false) { val path = new Path(dir) val fs = path.getFileSystem(new Configuration()) - if (fs.exists(path)) { - throw new Exception("Checkpoint directory '" + path + "' already exists.") - } else { - fs.mkdirs(path) - if (deleteOnExit) fs.deleteOnExit(path) + if (!useExisting) { + if (fs.exists(path)) { + throw new Exception("Checkpoint directory '" + path + "' already exists.") + } else { + fs.mkdirs(path) + } } checkpointDir = dir } diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 83a43d15cb..cf04c7031e 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -1,6 +1,6 @@ package spark.streaming -import spark.Utils +import spark.{Logging, Utils} import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration @@ -8,13 +8,14 @@ import org.apache.hadoop.conf.Configuration import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream} -class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Serializable { +class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) + extends Logging with Serializable { val master = ssc.sc.master val framework = ssc.sc.jobName val sparkHome = ssc.sc.sparkHome val jars = ssc.sc.jars val graph = ssc.graph - val checkpointFile = ssc.checkpointFile + val checkpointDir = ssc.checkpointDir val checkpointInterval = ssc.checkpointInterval validate() @@ -24,22 +25,25 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext assert(framework != null, "Checkpoint.framework is null") assert(graph != null, "Checkpoint.graph is null") assert(checkpointTime != null, "Checkpoint.checkpointTime is null") + logInfo("Checkpoint for time " + checkpointTime + " validated") } - def saveToFile(file: String = checkpointFile) { - val path = new Path(file) + def save(path: String) { + val file = new Path(path, "graph") val conf = new Configuration() - val fs = path.getFileSystem(conf) - if (fs.exists(path)) { - val bkPath = new Path(path.getParent, path.getName + ".bk") - FileUtil.copy(fs, path, fs, bkPath, true, true, conf) - //logInfo("Moved existing checkpoint file to " + bkPath) + val fs = file.getFileSystem(conf) + logDebug("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'") + if (fs.exists(file)) { + val bkFile = new Path(file.getParent, file.getName + ".bk") + FileUtil.copy(fs, file, fs, bkFile, true, true, conf) + logDebug("Moved existing checkpoint file to " + bkFile) } - val fos = 
fs.create(path) + val fos = fs.create(file) val oos = new ObjectOutputStream(fos) oos.writeObject(this) oos.close() fs.close() + logInfo("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'") } def toBytes(): Array[Byte] = { @@ -50,30 +54,41 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext object Checkpoint { - def loadFromFile(file: String): Checkpoint = { - try { - val path = new Path(file) - val conf = new Configuration() - val fs = path.getFileSystem(conf) - if (!fs.exists(path)) { - throw new Exception("Checkpoint file '" + file + "' does not exist") + def load(path: String): Checkpoint = { + + val fs = new Path(path).getFileSystem(new Configuration()) + val attempts = Seq(new Path(path), new Path(path, "graph"), new Path(path, "graph.bk")) + var lastException: Exception = null + var lastExceptionFile: String = null + + attempts.foreach(file => { + if (fs.exists(file)) { + try { + val fis = fs.open(file) + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader) + val cp = ois.readObject.asInstanceOf[Checkpoint] + ois.close() + fs.close() + cp.validate() + println("Checkpoint successfully loaded from file " + file) + return cp + } catch { + case e: Exception => + lastException = e + lastExceptionFile = file.toString + } } - val fis = fs.open(path) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. 
This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader) - val cp = ois.readObject.asInstanceOf[Checkpoint] - ois.close() - fs.close() - cp.validate() - cp - } catch { - case e: Exception => - e.printStackTrace() - throw new Exception("Could not load checkpoint file '" + file + "'", e) + }) + + if (lastException == null) { + throw new Exception("Could not load checkpoint from path '" + path + "'") + } else { + throw new Exception("Error loading checkpoint from path '" + lastExceptionFile + "'", lastException) } } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index a4921bb1a2..de51c5d34a 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -13,6 +13,7 @@ import scala.collection.mutable.HashMap import java.util.concurrent.ArrayBlockingQueue import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import scala.Some +import collection.mutable abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) extends Serializable with Logging { @@ -41,53 +42,55 @@ extends Serializable with Logging { */ // RDDs generated, marked as protected[streaming] so that testsuites can access it - protected[streaming] val generatedRDDs = new HashMap[Time, RDD[T]] () + protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () // Time zero for the DStream - protected var zeroTime: Time = null + protected[streaming] var zeroTime: Time = null // Duration for which the DStream will remember each RDD created - protected var rememberDuration: Time = null + protected[streaming] var rememberDuration: Time = null // Storage level of the RDDs in the stream - protected var storageLevel: StorageLevel = StorageLevel.NONE + protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE - // Checkpoint level and checkpoint interval - protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint - protected var checkpointInterval: Time = null + // Checkpoint details + protected[streaming] val mustCheckpoint = false + protected[streaming] var checkpointInterval: Time = null + protected[streaming] val checkpointData = new HashMap[Time, Any]() // Reference to whole DStream graph - protected var graph: DStreamGraph = null + protected[streaming] var graph: DStreamGraph = null def isInitialized = (zeroTime != null) // Duration for which the DStream requires its parent DStream to remember each RDD created def parentRememberDuration = rememberDuration - // Change this RDD's storage level - def persist( - storageLevel: StorageLevel, - checkpointLevel: StorageLevel, - checkpointInterval: Time): DStream[T] = { - if (this.storageLevel != StorageLevel.NONE && this.storageLevel != storageLevel) { - // TODO: not sure this is necessary for DStreams + // Set caching level for the RDDs created by this DStream + def persist(level: StorageLevel): DStream[T] = { + if (this.isInitialized) { throw new UnsupportedOperationException( - "Cannot change storage level of an DStream after it was already assigned a level") + "Cannot change storage level of an DStream after streaming context has started") } - this.storageLevel = storageLevel - this.checkpointLevel = checkpointLevel - this.checkpointInterval = checkpointInterval + this.storageLevel = level this } - 
// Set caching level for the RDDs created by this DStream - def persist(newLevel: StorageLevel): DStream[T] = persist(newLevel, StorageLevel.NONE, null) - def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY) // Turn on the default caching level for this RDD def cache(): DStream[T] = persist() + def checkpoint(interval: Time): DStream[T] = { + if (isInitialized) { + throw new UnsupportedOperationException( + "Cannot change checkpoint interval of an DStream after streaming context has started") + } + persist() + checkpointInterval = interval + this + } + /** * This method initializes the DStream by setting the "zero" time, based on which * the validity of future times is calculated. This method also recursively initializes @@ -99,7 +102,67 @@ extends Serializable with Logging { + ", cannot initialize it again to " + time) } zeroTime = time + + // Set the checkpoint interval to be slideTime or 10 seconds, which ever is larger + if (mustCheckpoint && checkpointInterval == null) { + checkpointInterval = slideTime.max(Seconds(10)) + logInfo("Checkpoint interval automatically set to " + checkpointInterval) + } + + // Set the minimum value of the rememberDuration if not already set + var minRememberDuration = slideTime + if (checkpointInterval != null && minRememberDuration <= checkpointInterval) { + minRememberDuration = checkpointInterval + slideTime + } + if (rememberDuration == null || rememberDuration < minRememberDuration) { + rememberDuration = minRememberDuration + } + + // Initialize the dependencies dependencies.foreach(_.initialize(zeroTime)) + } + + protected[streaming] def validate() { + assert( + !mustCheckpoint || checkpointInterval != null, + "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " + + " Please use DStream.checkpoint() to set the interval." + ) + + assert( + checkpointInterval == null || checkpointInterval >= slideTime, + "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + + checkpointInterval + " which is lower than its slide time (" + slideTime + "). " + + "Please set it to at least " + slideTime + "." + ) + + assert( + checkpointInterval == null || checkpointInterval.isMultipleOf(slideTime), + "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + + checkpointInterval + " which not a multiple of its slide time (" + slideTime + "). " + + "Please set it to a multiple " + slideTime + "." + ) + + assert( + checkpointInterval == null || storageLevel != StorageLevel.NONE, + "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " + + "level has not been set to enable persisting. Please use DStream.persist() to set the " + + "storage level to use memory for better checkpointing performance." + ) + + assert( + checkpointInterval == null || rememberDuration > checkpointInterval, + "The remember duration for " + this.getClass.getSimpleName + " has been set to " + + rememberDuration + " which is not more than the checkpoint interval (" + + checkpointInterval + "). Please set it to higher than " + checkpointInterval + "." 
+ ) + + dependencies.foreach(_.validate()) + + logInfo("Slide time = " + slideTime) + logInfo("Storage level = " + storageLevel) + logInfo("Checkpoint interval = " + checkpointInterval) + logInfo("Remember duration = " + rememberDuration) logInfo("Initialized " + this) } @@ -120,17 +183,12 @@ extends Serializable with Logging { dependencies.foreach(_.setGraph(graph)) } - protected[streaming] def setRememberDuration(duration: Time = slideTime) { - if (duration == null) { - throw new Exception("Duration for remembering RDDs cannot be set to null for " + this) - } else if (rememberDuration != null && duration < rememberDuration) { - logWarning("Duration for remembering RDDs cannot be reduced from " + rememberDuration - + " to " + duration + " for " + this) - } else { + protected[streaming] def setRememberDuration(duration: Time) { + if (duration != null && duration > rememberDuration) { rememberDuration = duration - dependencies.foreach(_.setRememberDuration(parentRememberDuration)) logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) } + dependencies.foreach(_.setRememberDuration(parentRememberDuration)) } /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ @@ -163,12 +221,13 @@ extends Serializable with Logging { if (isTimeValid(time)) { compute(time) match { case Some(newRDD) => - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { - newRDD.persist(checkpointLevel) - logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) - } else if (storageLevel != StorageLevel.NONE) { + if (storageLevel != StorageLevel.NONE) { newRDD.persist(storageLevel) - logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) + logInfo("Persisting RDD for time " + time + " to " + storageLevel + " at time " + time) + } + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.checkpoint() + logInfo("Marking RDD for time " + time + " for checkpointing at time " + time) } generatedRDDs.put(time, newRDD) Some(newRDD) @@ -199,7 +258,7 @@ extends Serializable with Logging { } } - def forgetOldRDDs(time: Time) { + protected[streaming] def forgetOldRDDs(time: Time) { val keys = generatedRDDs.keys var numForgotten = 0 keys.foreach(t => { @@ -213,12 +272,35 @@ extends Serializable with Logging { dependencies.foreach(_.forgetOldRDDs(time)) } + protected[streaming] def updateCheckpointData() { + checkpointData.clear() + generatedRDDs.foreach { + case(time, rdd) => { + logDebug("Adding checkpointed RDD for time " + time) + val data = rdd.getCheckpointData() + if (data != null) { + checkpointData += ((time, data)) + } + } + } + } + + protected[streaming] def restoreCheckpointData() { + checkpointData.foreach { + case(time, data) => { + logInfo("Restoring checkpointed RDD for time " + time) + generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString))) + } + } + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { logDebug(this.getClass().getSimpleName + ".writeObject used") if (graph != null) { graph.synchronized { if (graph.checkpointInProgress) { + updateCheckpointData() oos.defaultWriteObject() } else { val msg = "Object of " + this.getClass.getName + " is being serialized " + @@ -239,6 +321,8 @@ extends Serializable with Logging { private def readObject(ois: ObjectInputStream) { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() + generatedRDDs = new HashMap[Time, RDD[T]] () 
+ restoreCheckpointData() } /** diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index ac44d7a2a6..f8922ec790 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -22,11 +22,8 @@ final class DStreamGraph extends Serializable with Logging { } zeroTime = time outputStreams.foreach(_.initialize(zeroTime)) - outputStreams.foreach(_.setRememberDuration()) // first set the rememberDuration to default values - if (rememberDuration != null) { - // if custom rememberDuration has been provided, set the rememberDuration - outputStreams.foreach(_.setRememberDuration(rememberDuration)) - } + outputStreams.foreach(_.setRememberDuration(rememberDuration)) + outputStreams.foreach(_.validate) inputStreams.par.foreach(_.start()) } } diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index 1c57d5f855..6df82c0df3 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -21,15 +21,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( partitioner: Partitioner ) extends DStream[(K,V)](parent.ssc) { - if (!_windowTime.isMultipleOf(parent.slideTime)) - throw new Exception("The window duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + assert(_windowTime.isMultipleOf(parent.slideTime), + "The window duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + ) - if (!_slideTime.isMultipleOf(parent.slideTime)) - throw new Exception("The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + assert(_slideTime.isMultipleOf(parent.slideTime), + "The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + ) - @transient val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + super.persist(StorageLevel.MEMORY_ONLY) + + val reducedStream = parent.reduceByKey(reduceFunc, partitioner) def windowTime: Time = _windowTime @@ -37,15 +41,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( override def slideTime: Time = _slideTime - //TODO: This is wrong. 
This should depend on the checkpointInterval + override val mustCheckpoint = true + override def parentRememberDuration: Time = rememberDuration + windowTime - override def persist( - storageLevel: StorageLevel, - checkpointLevel: StorageLevel, - checkpointInterval: Time): DStream[(K,V)] = { - super.persist(storageLevel, checkpointLevel, checkpointInterval) - reducedStream.persist(storageLevel, checkpointLevel, checkpointInterval) + override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { + super.persist(storageLevel) + reducedStream.persist(storageLevel) + this + } + + override def checkpoint(interval: Time): DStream[(K, V)] = { + super.checkpoint(interval) + reducedStream.checkpoint(interval) this } diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index 086752ac55..0211df1343 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -23,51 +23,14 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife rememberPartitioner: Boolean ) extends DStream[(K, S)](parent.ssc) { + super.persist(StorageLevel.MEMORY_ONLY) + override def dependencies = List(parent) override def slideTime = parent.slideTime - override def getOrCompute(time: Time): Option[RDD[(K, S)]] = { - generatedRDDs.get(time) match { - case Some(oldRDD) => { - if (checkpointInterval != null && time > zeroTime && (time - zeroTime).isMultipleOf(checkpointInterval) && oldRDD.dependencies.size > 0) { - val r = oldRDD - val oldRDDBlockIds = oldRDD.splits.map(s => "rdd:" + r.id + ":" + s.index) - val checkpointedRDD = new BlockRDD[(K, S)](ssc.sc, oldRDDBlockIds) { - override val partitioner = oldRDD.partitioner - } - generatedRDDs.update(time, checkpointedRDD) - logInfo("Checkpointed RDD " + oldRDD.id + " of time " + time + " with its new RDD " + checkpointedRDD.id) - Some(checkpointedRDD) - } else { - Some(oldRDD) - } - } - case None => { - if (isTimeValid(time)) { - compute(time) match { - case Some(newRDD) => { - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { - newRDD.persist(checkpointLevel) - logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) - } else if (storageLevel != StorageLevel.NONE) { - newRDD.persist(storageLevel) - logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) - } - generatedRDDs.put(time, newRDD) - Some(newRDD) - } - case None => { - None - } - } - } else { - None - } - } - } - } - + override val mustCheckpoint = true + override def compute(validTime: Time): Option[RDD[(K, S)]] = { // Try to get the previous state RDD diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index b3148eaa97..3838e84113 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -15,6 +15,8 @@ import org.apache.hadoop.io.LongWritable import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.apache.hadoop.fs.Path +import java.util.UUID class StreamingContext ( sc_ : SparkContext, @@ -26,7 +28,7 @@ class StreamingContext ( def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) = this(new SparkContext(master, 
frameworkName, sparkHome, jars), null) - def this(file: String) = this(null, Checkpoint.loadFromFile(file)) + def this(path: String) = this(null, Checkpoint.load(path)) def this(cp_ : Checkpoint) = this(null, cp_) @@ -51,7 +53,6 @@ class StreamingContext ( val graph: DStreamGraph = { if (isCheckpointPresent) { - cp_.graph.setContext(this) cp_.graph } else { @@ -62,7 +63,15 @@ class StreamingContext ( val nextNetworkInputStreamId = new AtomicInteger(0) var networkInputTracker: NetworkInputTracker = null - private[streaming] var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null + private[streaming] var checkpointDir: String = { + if (isCheckpointPresent) { + sc.setCheckpointDir(cp_.checkpointDir, true) + cp_.checkpointDir + } else { + null + } + } + private[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null private[streaming] var receiverJobThread: Thread = null private[streaming] var scheduler: Scheduler = null @@ -75,9 +84,15 @@ class StreamingContext ( graph.setRememberDuration(duration) } - def setCheckpointDetails(file: String, interval: Time) { - checkpointFile = file - checkpointInterval = interval + def checkpoint(dir: String, interval: Time) { + if (dir != null) { + sc.setCheckpointDir(new Path(dir, "rdds-" + UUID.randomUUID.toString).toString) + checkpointDir = dir + checkpointInterval = interval + } else { + checkpointDir = null + checkpointInterval = null + } } private[streaming] def getInitialCheckpoint(): Checkpoint = { @@ -170,16 +185,12 @@ class StreamingContext ( graph.addOutputStream(outputStream) } - def validate() { - assert(graph != null, "Graph is null") - graph.validate() - } - /** * This function starts the execution of the streams. */ def start() { - validate() + assert(graph != null, "Graph is null") + graph.validate() val networkInputStreams = graph.getInputStreams().filter(s => s match { case n: NetworkInputDStream[_] => true @@ -216,7 +227,8 @@ class StreamingContext ( } def doCheckpoint(currentTime: Time) { - new Checkpoint(this, currentTime).saveToFile(checkpointFile) + new Checkpoint(this, currentTime).save(checkpointDir) + } } diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 9ddb65249a..2ba6502971 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -25,6 +25,10 @@ case class Time(millis: Long) { def isMultipleOf(that: Time): Boolean = (this.millis % that.millis == 0) + def min(that: Time): Time = if (this < that) this else that + + def max(that: Time): Time = if (this > that) this else that + def isZero: Boolean = (this.millis == 0) override def toString: String = (millis.toString + " ms") diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala index df96a811da..21a83c0fde 100644 --- a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala +++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala @@ -10,20 +10,20 @@ object FileStreamWithCheckpoint { def main(args: Array[String]) { if (args.size != 3) { - println("FileStreamWithCheckpoint ") - println("FileStreamWithCheckpoint restart ") + println("FileStreamWithCheckpoint ") + println("FileStreamWithCheckpoint restart ") System.exit(-1) } val directory = new Path(args(1)) - val checkpointFile = args(2) + val 
checkpointDir = args(2) val ssc: StreamingContext = { if (args(0) == "restart") { // Recreated streaming context from specified checkpoint file - new StreamingContext(checkpointFile) + new StreamingContext(checkpointDir) } else { @@ -34,7 +34,7 @@ object FileStreamWithCheckpoint { // Create new streaming context val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint") ssc_.setBatchDuration(Seconds(1)) - ssc_.setCheckpointDetails(checkpointFile, Seconds(1)) + ssc_.checkpoint(checkpointDir, Seconds(1)) // Setup the streaming computation val inputStream = ssc_.textFileStream(directory.toString) diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index 57fd10f0a5..750cb7445f 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -41,9 +41,8 @@ object TopKWordCountRaw { val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - Milliseconds(chkptMs)) - //windowedCounts.print() // TODO: something else? + windowedCounts.persist().checkpoint(Milliseconds(chkptMs)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs)) def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = { val taken = new Array[(String, Long)](k) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index 0d2e62b955..865026033e 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -100,10 +100,9 @@ object WordCount2 { val windowedCounts = sentences .mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, - StorageLevel.MEMORY_ONLY_2, - //new StorageLevel(false, true, true, 3), - Milliseconds(chkptMillis.toLong)) + + windowedCounts.persist().checkpoint(Milliseconds(chkptMillis.toLong)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index abfd12890f..d1ea9a9cd5 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -41,9 +41,9 @@ object WordCountRaw { val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - Milliseconds(chkptMs)) - //windowedCounts.print() // TODO: something else? 
+ windowedCounts.persist().checkpoint(chkptMs) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs)) + windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala index 9d44da2b11..6a9c8a9a69 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala @@ -57,11 +57,13 @@ object WordMax2 { val windowedCounts = sentences .mapPartitions(splitAndCountPartitions) .reduceByKey(add _, reduceTasks.toInt) - .persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - Milliseconds(chkptMillis.toLong)) + .persist() + .checkpoint(Milliseconds(chkptMillis.toLong)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) .reduceByKeyAndWindow(max _, Seconds(10), batchDuration, reduceTasks.toInt) - //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - // Milliseconds(chkptMillis.toLong)) + .persist() + .checkpoint(Milliseconds(chkptMillis.toLong)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 6dcedcf463..dfe31b5771 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -2,52 +2,95 @@ package spark.streaming import spark.streaming.StreamingContext._ import java.io.File +import collection.mutable.ArrayBuffer +import runtime.RichInt +import org.scalatest.BeforeAndAfter +import org.apache.hadoop.fs.Path +import org.apache.commons.io.FileUtils -class CheckpointSuite extends TestSuiteBase { +class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { + + before { + FileUtils.deleteDirectory(new File(checkpointDir)) + } + + after { + FileUtils.deleteDirectory(new File(checkpointDir)) + } override def framework() = "CheckpointSuite" - override def checkpointFile() = "checkpoint" + override def batchDuration() = Seconds(1) + + override def checkpointDir() = "checkpoint" + + override def checkpointInterval() = batchDuration def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V], expectedOutput: Seq[Seq[V]], - useSet: Boolean = false + initialNumBatches: Int ) { // Current code assumes that: // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val initialNumBatches = input.size / 2 val nextNumBatches = totalNumBatches - initialNumBatches val initialNumExpectedOutputs = initialNumBatches + val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs // Do half the computation (half the number of batches), create checkpoint file and quit val ssc = setupStreams[U, V](input, operation) val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) - verifyOutput[V](output, expectedOutput.take(initialNumBatches), useSet) + verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) Thread.sleep(1000) // Restart and complete the computation from checkpoint file - val sscNew = new StreamingContext(checkpointFile) - 
sscNew.setCheckpointDetails(null, null) - val outputNew = runStreams[V](sscNew, nextNumBatches, expectedOutput.size) - verifyOutput[V](outputNew, expectedOutput, useSet) - - new File(checkpointFile).delete() - new File(checkpointFile + ".bk").delete() - new File("." + checkpointFile + ".crc").delete() - new File("." + checkpointFile + ".bk.crc").delete() + val sscNew = new StreamingContext(checkpointDir) + //sscNew.checkpoint(null, null) + val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs) + verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) } - test("simple per-batch operation") { + + test("map and reduceByKey") { testCheckpointedOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), - true + 3 ) } + + test("reduceByKeyAndWindowInv") { + val n = 10 + val w = 4 + val input = (1 to n).map(x => Seq("a")).toSeq + val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) + val operation = (st: DStream[String]) => { + st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, Seconds(w), Seconds(1)) + } + for (i <- Seq(3, 5, 7)) { + testCheckpointedOperation(input, operation, output, i) + } + } + + test("updateStateByKey") { + val input = (1 to 10).map(_ => Seq("a")).toSeq + val output = (1 to 10).map(x => Seq(("a", x))).toSeq + val operation = (st: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) + } + st.map(x => (x, 1)) + .updateStateByKey[RichInt](updateFunc) + .checkpoint(Seconds(5)) + .map(t => (t._1, t._2.self)) + } + for (i <- Seq(3, 5, 7)) { + testCheckpointedOperation(input, operation, output, i) + } + } + + } \ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index c9bc454f91..e441feea19 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -5,10 +5,16 @@ import util.ManualClock import collection.mutable.ArrayBuffer import org.scalatest.FunSuite import collection.mutable.SynchronizedBuffer +import java.io.{ObjectInputStream, IOException} + +/** + * This is an input stream just for the testsuites. This is equivalent to a checkpointable, + * replayable, reliable message queue like Kafka. It requires a sequence as input, and + * returns the i_th element at the i_th batch under manual clock. + */ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int) extends InputDStream[T](ssc_) { - var currentIndex = 0 def start() {} @@ -23,17 +29,32 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[ ssc.sc.makeRDD(Seq[T](), numPartitions) } logInfo("Created RDD " + rdd.id) - //currentIndex += 1 Some(rdd) } } +/** + * This is an output stream just for the testsuites. All the output is collected into an + * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ */ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]]) extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() output += collected - }) + }) { + + // This is to clear the output buffer every time it is read from a checkpoint + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + ois.defaultReadObject() + output.clear() + } +} +/** + * This is the base trait for Spark Streaming testsuites. This provides basic functionality + * to run a user-defined set of input on user-defined stream operations, and verify the output. + */ trait TestSuiteBase extends FunSuite with Logging { System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") @@ -44,7 +65,7 @@ trait TestSuiteBase extends FunSuite with Logging { def batchDuration() = Seconds(1) - def checkpointFile() = null.asInstanceOf[String] + def checkpointDir() = null.asInstanceOf[String] def checkpointInterval() = batchDuration @@ -60,8 +81,8 @@ trait TestSuiteBase extends FunSuite with Logging { // Create StreamingContext val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - if (checkpointFile != null) { - ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) + if (checkpointDir != null) { + ssc.checkpoint(checkpointDir, checkpointInterval()) } // Setup the stream computation @@ -82,8 +103,8 @@ trait TestSuiteBase extends FunSuite with Logging { // Create StreamingContext val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - if (checkpointFile != null) { - ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) + if (checkpointDir != null) { + ssc.checkpoint(checkpointDir, checkpointInterval()) } // Setup the stream computation diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index d7d8d5bd36..e282f0fdd5 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -283,7 +283,9 @@ class WindowOperationsSuite extends TestSuiteBase { test("reduceByKeyAndWindowInv - " + name) { val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist() + s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime) + .persist() + .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing } testOperation(input, operation, expectedOutput, numBatches, true) } -- cgit v1.2.3 From 72b2303f99bd652fc4bdaa929f37731a7ba8f640 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 5 Nov 2012 11:41:36 -0800 Subject: Fixed major bugs in checkpointing.
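For orientation: the heart of this fix is that Checkpoint.load now walks an ordered list of candidate files (graph, graph.bk, then the older single-file layout), logs each failure, and only gives up once every candidate has been tried. Below is a minimal, self-contained sketch of that fallback-load pattern; it uses plain java.io instead of the Hadoop FileSystem API, and the names CheckpointLoadSketch and loadFirst are illustrative only, not part of Spark.

import java.io.{File, FileInputStream, ObjectInputStream}

object CheckpointLoadSketch {
  // Return the first candidate file that exists and deserializes cleanly, or None.
  def loadFirst[T](candidates: Seq[File]): Option[T] = {
    val it = candidates.iterator.filter(_.exists())
    while (it.hasNext) {
      val file = it.next()
      try {
        val ois = new ObjectInputStream(new FileInputStream(file))
        try {
          return Some(ois.readObject().asInstanceOf[T])  // first good file wins
        } finally {
          ois.close()
        }
      } catch {
        case e: Exception =>
          // Log and fall through to the next candidate, as the patched load() does.
          println("Skipping unreadable checkpoint file '" + file + "': " + e)
      }
    }
    None
  }
}

A caller would pass the candidates in preference order, for example loadFirst[Checkpoint](Seq(new File(dir, "graph"), new File(dir, "graph.bk"))), and treat None as "no usable checkpoint".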
--- core/src/main/scala/spark/SparkContext.scala | 6 +- .../main/scala/spark/streaming/Checkpoint.scala | 24 ++-- .../src/main/scala/spark/streaming/DStream.scala | 47 +++++-- .../main/scala/spark/streaming/DStreamGraph.scala | 36 ++++-- .../src/main/scala/spark/streaming/Scheduler.scala | 1 - .../scala/spark/streaming/StreamingContext.scala | 8 +- .../scala/spark/streaming/CheckpointSuite.scala | 139 ++++++++++++++++----- .../test/scala/spark/streaming/TestSuiteBase.scala | 37 ++++-- .../spark/streaming/WindowOperationsSuite.scala | 6 +- 9 files changed, 217 insertions(+), 87 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d7326971a9..d7b46bee38 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -474,8 +474,10 @@ class SparkContext( /** Shut down the SparkContext. */ def stop() { - dagScheduler.stop() - dagScheduler = null + if (dagScheduler != null) { + dagScheduler.stop() + dagScheduler = null + } taskScheduler = null // TODO: Cache.stop()? env.stop() diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index cf04c7031e..6b4b05103f 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -6,6 +6,7 @@ import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream} +import sys.process.processInternal class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) @@ -52,17 +53,17 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) } } -object Checkpoint { +object Checkpoint extends Logging { def load(path: String): Checkpoint = { val fs = new Path(path).getFileSystem(new Configuration()) - val attempts = Seq(new Path(path), new Path(path, "graph"), new Path(path, "graph.bk")) - var lastException: Exception = null - var lastExceptionFile: String = null + val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk")) + var detailedLog: String = "" attempts.foreach(file => { if (fs.exists(file)) { + logInfo("Attempting to load checkpoint from file '" + file + "'") try { val fis = fs.open(file) // ObjectInputStream uses the last defined user-defined class loader in the stack @@ -75,21 +76,18 @@ object Checkpoint { ois.close() fs.close() cp.validate() - println("Checkpoint successfully loaded from file " + file) + logInfo("Checkpoint successfully loaded from file '" + file + "'") return cp } catch { case e: Exception => - lastException = e - lastExceptionFile = file.toString + logError("Error loading checkpoint from file '" + file + "'", e) } + } else { + logWarning("Could not load checkpoint from file '" + file + "' as it does not exist") } - }) - if (lastException == null) { - throw new Exception("Could not load checkpoint from path '" + path + "'") - } else { - throw new Exception("Error loading checkpoint from path '" + lastExceptionFile + "'", lastException) - } + }) + throw new Exception("Could not load checkpoint from path '" + path + "'") } def fromBytes(bytes: Array[Byte]): Checkpoint = { diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index de51c5d34a..2fecbe0acf 100644 --- 
a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -14,6 +14,8 @@ import java.util.concurrent.ArrayBlockingQueue import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import scala.Some import collection.mutable +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) extends Serializable with Logging { @@ -42,6 +44,7 @@ extends Serializable with Logging { */ // RDDs generated, marked as protected[streaming] so that testsuites can access it + @transient protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () // Time zero for the DStream @@ -112,7 +115,7 @@ extends Serializable with Logging { // Set the minimum value of the rememberDuration if not already set var minRememberDuration = slideTime if (checkpointInterval != null && minRememberDuration <= checkpointInterval) { - minRememberDuration = checkpointInterval + slideTime + minRememberDuration = checkpointInterval * 2 // times 2 just to be sure that the latest checkpoint is not forgotten } if (rememberDuration == null || rememberDuration < minRememberDuration) { rememberDuration = minRememberDuration @@ -265,33 +268,59 @@ extends Serializable with Logging { if (t <= (time - rememberDuration)) { generatedRDDs.remove(t) numForgotten += 1 - //logInfo("Forgot RDD of time " + t + " from " + this) + logInfo("Forgot RDD of time " + t + " from " + this) } }) logInfo("Forgot " + numForgotten + " RDDs from " + this) dependencies.foreach(_.forgetOldRDDs(time)) } + /** + * Refreshes the list of checkpointed RDDs that will be saved along with the checkpoint of this stream. + * Along with that, it forgets old checkpoint files. + */ protected[streaming] def updateCheckpointData() { + + // TODO (tdas): This code can be simplified. It's kept verbose to aid debugging.
+ val checkpointedRDDs = generatedRDDs.filter(_._2.getCheckpointData() != null) + val removedCheckpointData = checkpointData.filter(x => !generatedRDDs.contains(x._1)) + checkpointData.clear() - generatedRDDs.foreach { - case(time, rdd) => { - logDebug("Adding checkpointed RDD for time " + time) + checkpointedRDDs.foreach { + case (time, rdd) => { val data = rdd.getCheckpointData() - if (data != null) { - checkpointData += ((time, data)) + assert(data != null) + checkpointData += ((time, data)) + logInfo("Added checkpointed RDD " + rdd + " for time " + time + " to stream checkpoint") + } + } + + dependencies.foreach(_.updateCheckpointData()) + // If at least one checkpoint is present, then delete old checkpoints + if (checkpointData.size > 0) { + // Delete the checkpoint RDD files that are not needed any more + removedCheckpointData.foreach { + case (time: Time, file: String) => { + val path = new Path(file) + val fs = path.getFileSystem(new Configuration()) + fs.delete(path, true) + logInfo("Deleted checkpoint file '" + file + "' for time " + time) } } } + + logInfo("Updated checkpoint data") } protected[streaming] def restoreCheckpointData() { + logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs") checkpointData.foreach { case(time, data) => { - logInfo("Restoring checkpointed RDD for time " + time) + logInfo("Restoring checkpointed RDD for time " + time + " from file") generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString))) } } + dependencies.foreach(_.restoreCheckpointData()) } @throws(classOf[IOException]) @@ -300,7 +329,6 @@ extends Serializable with Logging { if (graph != null) { graph.synchronized { if (graph.checkpointInProgress) { - updateCheckpointData() oos.defaultWriteObject() } else { val msg = "Object of " + this.getClass.getName + " is being serialized " + @@ -322,7 +350,6 @@ extends Serializable with Logging { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() generatedRDDs = new HashMap[Time, RDD[T]] () - restoreCheckpointData() } /** diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index f8922ec790..7437f4402d 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -4,7 +4,7 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import collection.mutable.ArrayBuffer import spark.Logging -final class DStreamGraph extends Serializable with Logging { +final private[streaming] class DStreamGraph extends Serializable with Logging { initLogging() private val inputStreams = new ArrayBuffer[InputDStream[_]]() @@ -15,7 +15,7 @@ final class DStreamGraph extends Serializable with Logging { private[streaming] var rememberDuration: Time = null private[streaming] var checkpointInProgress = false - def start(time: Time) { + private[streaming] def start(time: Time) { this.synchronized { if (zeroTime != null) { throw new Exception("DStream graph computation already started") @@ -28,7 +28,7 @@ final class DStreamGraph extends Serializable with Logging { } } - def stop() { + private[streaming] def stop() { this.synchronized { inputStreams.par.foreach(_.stop()) } @@ -40,7 +40,7 @@ final class DStreamGraph extends Serializable with Logging { } } - def setBatchDuration(duration: Time) { + private[streaming] def setBatchDuration(duration: Time) { this.synchronized { if (batchDuration != null) { throw new Exception("Batch duration 
already set as " + batchDuration + @@ -50,7 +50,7 @@ final class DStreamGraph extends Serializable with Logging { batchDuration = duration } - def setRememberDuration(duration: Time) { + private[streaming] def setRememberDuration(duration: Time) { this.synchronized { if (rememberDuration != null) { throw new Exception("Batch duration already set as " + batchDuration + @@ -60,37 +60,49 @@ final class DStreamGraph extends Serializable with Logging { rememberDuration = duration } - def addInputStream(inputStream: InputDStream[_]) { + private[streaming] def addInputStream(inputStream: InputDStream[_]) { this.synchronized { inputStream.setGraph(this) inputStreams += inputStream } } - def addOutputStream(outputStream: DStream[_]) { + private[streaming] def addOutputStream(outputStream: DStream[_]) { this.synchronized { outputStream.setGraph(this) outputStreams += outputStream } } - def getInputStreams() = inputStreams.toArray + private[streaming] def getInputStreams() = this.synchronized { inputStreams.toArray } - def getOutputStreams() = outputStreams.toArray + private[streaming] def getOutputStreams() = this.synchronized { outputStreams.toArray } - def generateRDDs(time: Time): Seq[Job] = { + private[streaming] def generateRDDs(time: Time): Seq[Job] = { this.synchronized { outputStreams.flatMap(outputStream => outputStream.generateJob(time)) } } - def forgetOldRDDs(time: Time) { + private[streaming] def forgetOldRDDs(time: Time) { this.synchronized { outputStreams.foreach(_.forgetOldRDDs(time)) } } - def validate() { + private[streaming] def updateCheckpointData() { + this.synchronized { + outputStreams.foreach(_.updateCheckpointData()) + } + } + + private[streaming] def restoreCheckpointData() { + this.synchronized { + outputStreams.foreach(_.restoreCheckpointData()) + } + } + + private[streaming] def validate() { this.synchronized { assert(batchDuration != null, "Batch duration has not been set") assert(batchDuration > Milliseconds(100), "Batch duration of " + batchDuration + " is very low") diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 7d52e2eddf..2b3f5a4829 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -58,7 +58,6 @@ extends Logging { graph.forgetOldRDDs(time) if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { ssc.doCheckpoint(time) - logInfo("Checkpointed at time " + time) } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 3838e84113..fb36ab9dc9 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -54,6 +54,7 @@ class StreamingContext ( val graph: DStreamGraph = { if (isCheckpointPresent) { cp_.graph.setContext(this) + cp_.graph.restoreCheckpointData() cp_.graph } else { new DStreamGraph() @@ -218,17 +219,16 @@ class StreamingContext ( if (scheduler != null) scheduler.stop() if (networkInputTracker != null) networkInputTracker.stop() if (receiverJobThread != null) receiverJobThread.interrupt() - sc.stop() + sc.stop() + logInfo("StreamingContext stopped successfully") } catch { case e: Exception => logWarning("Error while stopping", e) } - - logInfo("StreamingContext stopped") } def doCheckpoint(currentTime: Time) { + graph.updateCheckpointData() new Checkpoint(this, 
currentTime).save(checkpointDir) - } } diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index dfe31b5771..aa8ded513c 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -2,11 +2,11 @@ package spark.streaming import spark.streaming.StreamingContext._ import java.io.File -import collection.mutable.ArrayBuffer import runtime.RichInt import org.scalatest.BeforeAndAfter -import org.apache.hadoop.fs.Path import org.apache.commons.io.FileUtils +import collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import util.ManualClock class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { @@ -18,39 +18,83 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { FileUtils.deleteDirectory(new File(checkpointDir)) } - override def framework() = "CheckpointSuite" + override def framework = "CheckpointSuite" - override def batchDuration() = Seconds(1) + override def batchDuration = Milliseconds(500) - override def checkpointDir() = "checkpoint" + override def checkpointDir = "checkpoint" - override def checkpointInterval() = batchDuration + override def checkpointInterval = batchDuration - def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - initialNumBatches: Int - ) { + override def actuallyWait = true - // Current code assumes that: - // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val nextNumBatches = totalNumBatches - initialNumBatches - val initialNumExpectedOutputs = initialNumBatches - val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + test("basic stream+rdd recovery") { - // Do half the computation (half the number of batches), create checkpoint file and quit - val ssc = setupStreams[U, V](input, operation) - val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) - verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) - Thread.sleep(1000) + assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - // Restart and complete the computation from checkpoint file + val checkpointingInterval = Seconds(2) + + // this ensure checkpointing occurs at least once + val firstNumBatches = (checkpointingInterval.millis / batchDuration.millis) * 2 + val secondNumBatches = firstNumBatches + + // Setup the streams + val input = (1 to 10).map(_ => Seq("a")).toSeq + val operation = (st: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) + } + st.map(x => (x, 1)) + .updateStateByKey[RichInt](updateFunc) + .checkpoint(checkpointingInterval) + .map(t => (t._1, t._2.self)) + } + val ssc = setupStreams(input, operation) + val stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head + + // Run till a time such that at least one RDD in the stream should have been checkpointed + ssc.start() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + logInfo("Manual clock before advancing = " + clock.time) + for (i <- 1 to firstNumBatches.toInt) { + clock.addToTime(batchDuration.milliseconds) + 
Thread.sleep(batchDuration.milliseconds) + } + logInfo("Manual clock after advancing = " + clock.time) + Thread.sleep(batchDuration.milliseconds) + + // Check whether some RDD has been checkpointed or not + logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.mkString(",\n") + "]") + assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream") + stateStream.checkpointData.foreach { + case (time, data) => { + val file = new File(data.toString) + assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " does not exist") + } + } + val checkpointFiles = stateStream.checkpointData.map(x => new File(x._2.toString)) + + // Run till a further time such that previous checkpoint files in the stream would be deleted + logInfo("Manual clock before advancing = " + clock.time) + for (i <- 1 to secondNumBatches.toInt) { + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(batchDuration.milliseconds) + } + logInfo("Manual clock after advancing = " + clock.time) + Thread.sleep(batchDuration.milliseconds) + + // Check whether the earlier checkpoint files are deleted + checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) + + // Restart stream computation using the checkpoint file and check whether + // checkpointed RDDs have been restored or not + ssc.stop() val sscNew = new StreamingContext(checkpointDir) - //sscNew.checkpoint(null, null) - val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs) - verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) + val stateStreamNew = sscNew.graph.getOutputStreams().head.dependencies.head.dependencies.head + logInfo("Restored data of state stream = \n[" + stateStreamNew.generatedRDDs.mkString("\n") + "]") + assert(!stateStreamNew.generatedRDDs.isEmpty, "No restored RDDs in state stream") + sscNew.stop() } @@ -69,9 +113,9 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val input = (1 to n).map(x => Seq("a")).toSeq val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) val operation = (st: DStream[String]) => { - st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, Seconds(w), Seconds(1)) + st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, batchDuration * 4, batchDuration) } - for (i <- Seq(3, 5, 7)) { + for (i <- Seq(2, 3, 4)) { testCheckpointedOperation(input, operation, output, i) } } @@ -85,12 +129,45 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { } st.map(x => (x, 1)) .updateStateByKey[RichInt](updateFunc) - .checkpoint(Seconds(5)) + .checkpoint(Seconds(2)) .map(t => (t._1, t._2.self)) } - for (i <- Seq(3, 5, 7)) { + for (i <- Seq(2, 3, 4)) { testCheckpointedOperation(input, operation, output, i) } } + + + def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + initialNumBatches: Int + ) { + + // Current code assumes that: + // number of inputs = number of outputs = number of batches to be run + val totalNumBatches = input.size + val nextNumBatches = totalNumBatches - initialNumBatches + val initialNumExpectedOutputs = initialNumBatches + val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + + // Do half the computation (half the number of batches), create checkpoint file and quit + + val ssc = setupStreams[U, V](input, operation) + val output = runStreams[V](ssc, 
initialNumBatches, initialNumExpectedOutputs) + verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) + Thread.sleep(1000) + + // Restart and complete the computation from checkpoint file + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation " + + "\n-------------------------------------------\n" + ) + val sscNew = new StreamingContext(checkpointDir) + val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs) + verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) + } } \ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index e441feea19..b8c7f99603 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -57,21 +57,21 @@ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBu */ trait TestSuiteBase extends FunSuite with Logging { - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + def framework = "TestSuiteBase" - def framework() = "TestSuiteBase" + def master = "local[2]" - def master() = "local[2]" + def batchDuration = Seconds(1) - def batchDuration() = Seconds(1) + def checkpointDir = null.asInstanceOf[String] - def checkpointDir() = null.asInstanceOf[String] + def checkpointInterval = batchDuration - def checkpointInterval() = batchDuration + def numInputPartitions = 2 - def numInputPartitions() = 2 + def maxWaitTimeMillis = 10000 - def maxWaitTimeMillis() = 10000 + def actuallyWait = false def setupStreams[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], @@ -82,7 +82,7 @@ trait TestSuiteBase extends FunSuite with Logging { val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) if (checkpointDir != null) { - ssc.checkpoint(checkpointDir, checkpointInterval()) + ssc.checkpoint(checkpointDir, checkpointInterval) } // Setup the stream computation @@ -104,7 +104,7 @@ trait TestSuiteBase extends FunSuite with Logging { val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) if (checkpointDir != null) { - ssc.checkpoint(checkpointDir, checkpointInterval()) + ssc.checkpoint(checkpointDir, checkpointInterval) } // Setup the stream computation @@ -118,12 +118,19 @@ trait TestSuiteBase extends FunSuite with Logging { ssc } + /** + * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and + * returns the collected output. It will wait until `numExpectedOutput` number of + * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached. 
+ */ def runStreams[V: ClassManifest]( ssc: StreamingContext, numBatches: Int, numExpectedOutput: Int ): Seq[Seq[V]] = { + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + assert(numBatches > 0, "Number of batches to run stream computation is zero") assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero") logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) @@ -139,7 +146,15 @@ trait TestSuiteBase extends FunSuite with Logging { // Advance manual clock val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.time) - clock.addToTime(numBatches * batchDuration.milliseconds) + if (actuallyWait) { + for (i <- 1 to numBatches) { + logInfo("Actually waiting for " + batchDuration) + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(batchDuration.milliseconds) + } + } else { + clock.addToTime(numBatches * batchDuration.milliseconds) + } logInfo("Manual clock after advancing = " + clock.time) // Wait until expected number of output items have been generated diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index e282f0fdd5..3e20e16708 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -5,11 +5,11 @@ import collection.mutable.ArrayBuffer class WindowOperationsSuite extends TestSuiteBase { - override def framework() = "WindowOperationsSuite" + override def framework = "WindowOperationsSuite" - override def maxWaitTimeMillis() = 20000 + override def maxWaitTimeMillis = 20000 - override def batchDuration() = Seconds(1) + override def batchDuration = Seconds(1) val largerSlideInput = Seq( Seq(("a", 1)), -- cgit v1.2.3 From 355c8e4b17cc3e67b1e18cc24e74d88416b5779b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 9 Nov 2012 16:28:45 -0800 Subject: Fixed deadlock in BlockManager. 
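For orientation: the fix below removes the hash-striped BlockLocker (337 locks shared across all block ids) and instead synchronizes on each block's own BlockInfo entry, with a separate putLock in MemoryStore serializing only the put/evict path. The following is a rough sketch of the per-entry locking idea; Info, infos and withBlock are made-up names and none of the real BlockManager logic is reproduced here.

import java.util.concurrent.ConcurrentHashMap

object PerEntryLockingSketch {
  class Info(var size: Long = 0L)            // stand-in for BlockInfo

  private val infos = new ConcurrentHashMap[String, Info]()

  // Before the fix: locker.getLock(blockId).synchronized { ... } forced unrelated blocks
  // whose ids hashed to the same stripe to wait on one shared lock.
  // After the fix: take the monitor of this particular block's entry only.
  def withBlock[A](blockId: String)(body: Info => A): Option[A] = {
    val info = infos.get(blockId)
    if (info == null) {
      None                                   // unknown block, nothing to lock
    } else {
      info.synchronized {
        Some(body(info))                     // contention is limited to this one block
      }
    }
  }
}

The design point is that lock scope shrinks from "any block in this stripe" to "this block", which is what lets reportBlockStatus, get and dropFromMemory run concurrently for different blocks.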
--- .../main/scala/spark/storage/BlockManager.scala | 111 ++++++++++----------- .../src/main/scala/spark/storage/MemoryStore.scala | 79 +++++++++------ 2 files changed, 101 insertions(+), 89 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index bd9155ef29..8c7b1417be 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -50,16 +50,6 @@ private[spark] case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message) - -private[spark] class BlockLocker(numLockers: Int) { - private val hashLocker = Array.fill(numLockers)(new Object()) - - def getLock(blockId: String): Object = { - return hashLocker(math.abs(blockId.hashCode % numLockers)) - } -} - - private[spark] class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, maxMemory: Long) extends Logging { @@ -87,10 +77,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - private val NUM_LOCKS = 337 - private val locker = new BlockLocker(NUM_LOCKS) - - private val blockInfo = new ConcurrentHashMap[String, BlockInfo]() + private val blockInfo = new ConcurrentHashMap[String, BlockInfo](1000) private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) private[storage] val diskStore: BlockStore = @@ -110,7 +97,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val maxBytesInFlight = System.getProperty("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024 + // Whether to compress broadcast variables that are stored val compressBroadcast = System.getProperty("spark.broadcast.compress", "true").toBoolean + // Whether to compress shuffle output that are stored val compressShuffle = System.getProperty("spark.shuffle.compress", "true").toBoolean // Whether to compress RDD partitions that are stored serialized val compressRdds = System.getProperty("spark.rdd.compress", "false").toBoolean @@ -150,28 +139,27 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk. 
*/ def reportBlockStatus(blockId: String) { - locker.getLock(blockId).synchronized { - val curLevel = blockInfo.get(blockId) match { - case null => - StorageLevel.NONE - case info => + val (curLevel, inMemSize, onDiskSize) = blockInfo.get(blockId) match { + case null => + (StorageLevel.NONE, 0L, 0L) + case info => + info.synchronized { info.level match { case null => - StorageLevel.NONE + (StorageLevel.NONE, 0L, 0L) case level => val inMem = level.useMemory && memoryStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) - new StorageLevel(onDisk, inMem, level.deserialized, level.replication) + ( + new StorageLevel(onDisk, inMem, level.deserialized, level.replication), + if (inMem) memoryStore.getSize(blockId) else 0L, + if (onDisk) diskStore.getSize(blockId) else 0L + ) } - } - master.mustHeartBeat(HeartBeat( - blockManagerId, - blockId, - curLevel, - if (curLevel.useMemory) memoryStore.getSize(blockId) else 0L, - if (curLevel.useDisk) diskStore.getSize(blockId) else 0L)) - logDebug("Told master about block " + blockId) + } } + master.mustHeartBeat(HeartBeat(blockManagerId, blockId, curLevel, inMemSize, onDiskSize)) + logDebug("Told master about block " + blockId) } /** @@ -213,9 +201,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - locker.getLock(blockId).synchronized { - val info = blockInfo.get(blockId) - if (info != null) { + val info = blockInfo.get(blockId) + if (info != null) { + info.synchronized { info.waitForReady() // In case the block is still being put() by another thread val level = info.level logDebug("Level for block " + blockId + " is " + level) @@ -273,9 +261,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } } - } else { - logDebug("Block " + blockId + " not registered locally") } + } else { + logDebug("Block " + blockId + " not registered locally") } return None } @@ -298,9 +286,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - locker.getLock(blockId).synchronized { - val info = blockInfo.get(blockId) - if (info != null) { + val info = blockInfo.get(blockId) + if (info != null) { + info.synchronized { info.waitForReady() // In case the block is still being put() by another thread val level = info.level logDebug("Level for block " + blockId + " is " + level) @@ -338,10 +326,11 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new Exception("Block " + blockId + " not found on disk, though it should be") } } - } else { - logDebug("Block " + blockId + " not registered locally") } + } else { + logDebug("Block " + blockId + " not registered locally") } + return None } @@ -583,7 +572,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // Size of the block in bytes (to return to caller) var size = 0L - locker.getLock(blockId).synchronized { + myInfo.synchronized { logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") @@ -681,7 +670,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m null } - locker.getLock(blockId).synchronized { + myInfo.synchronized { logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") @@ -779,26 +768,30 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { 
logInfo("Dropping block " + blockId + " from memory") - locker.getLock(blockId).synchronized { - val info = blockInfo.get(blockId) - val level = info.level - if (level.useDisk && !diskStore.contains(blockId)) { - logInfo("Writing block " + blockId + " to disk") - data match { - case Left(elements) => - diskStore.putValues(blockId, elements, level, false) - case Right(bytes) => - diskStore.putBytes(blockId, bytes, level) + val info = blockInfo.get(blockId) + if (info != null) { + info.synchronized { + val level = info.level + if (level.useDisk && !diskStore.contains(blockId)) { + logInfo("Writing block " + blockId + " to disk") + data match { + case Left(elements) => + diskStore.putValues(blockId, elements, level, false) + case Right(bytes) => + diskStore.putBytes(blockId, bytes, level) + } + } + memoryStore.remove(blockId) + if (info.tellMaster) { + reportBlockStatus(blockId) + } + if (!level.useDisk) { + // The block is completely gone from this node; forget it so we can put() it again later. + blockInfo.remove(blockId) } } - memoryStore.remove(blockId) - if (info.tellMaster) { - reportBlockStatus(blockId) - } - if (!level.useDisk) { - // The block is completely gone from this node; forget it so we can put() it again later. - blockInfo.remove(blockId) - } + } else { + // The block has already been dropped } } diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala index 773970446a..09769d1f7d 100644 --- a/core/src/main/scala/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/spark/storage/MemoryStore.scala @@ -17,13 +17,16 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true) private var currentMemory = 0L + // Object used to ensure that only one thread is putting blocks and if necessary, dropping + // blocks from the memory store. 
+ private val putLock = new Object() logInfo("MemoryStore started with capacity %s.".format(Utils.memoryBytesToString(maxMemory))) def freeMemory: Long = maxMemory - currentMemory override def getSize(blockId: String): Long = { - synchronized { + entries.synchronized { entries.get(blockId).size } } @@ -38,8 +41,6 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) tryToPut(blockId, elements, sizeEstimate, true) } else { val entry = new Entry(bytes, bytes.limit, false) - ensureFreeSpace(blockId, bytes.limit) - synchronized { entries.put(blockId, entry) } tryToPut(blockId, bytes, bytes.limit, false) } } @@ -63,7 +64,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def getBytes(blockId: String): Option[ByteBuffer] = { - val entry = synchronized { + val entry = entries.synchronized { entries.get(blockId) } if (entry == null) { @@ -76,7 +77,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def getValues(blockId: String): Option[Iterator[Any]] = { - val entry = synchronized { + val entry = entries.synchronized { entries.get(blockId) } if (entry == null) { @@ -90,7 +91,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def remove(blockId: String) { - synchronized { + entries.synchronized { val entry = entries.get(blockId) if (entry != null) { entries.remove(blockId) @@ -104,7 +105,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def clear() { - synchronized { + entries.synchronized { entries.clear() } logInfo("MemoryStore cleared") @@ -125,12 +126,22 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Try to put in a set of values, if we can free up enough space. The value should either be * an ArrayBuffer if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) * size must also be passed by the caller. + * + * Locks on the object putLock to ensure that all the put requests and its associated block + * dropping is done by only on thread at a time. Otherwise while one thread is dropping + * blocks to free memory for one block, another thread may use up the freed space for + * another block. */ private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = { - synchronized { + // TODO: Its possible to optimize the locking by locking entries only when selecting blocks + // to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has been + // released, it must be ensured that those to-be-dropped blocks are not double counted for + // freeing up more space for another block that needs to be put. Only then the actually dropping + // of blocks (and writing to disk if necessary) can proceed in parallel. + putLock.synchronized { if (ensureFreeSpace(blockId, size)) { val entry = new Entry(value, size, deserialized) - entries.put(blockId, entry) + entries.synchronized { entries.put(blockId, entry) } currentMemory += size if (deserialized) { logInfo("Block %s stored as values to memory (estimated size %s, free %s)".format( @@ -160,8 +171,8 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that * don't fit into memory that we want to avoid). * - * Assumes that a lock on the MemoryStore is held by the caller. (Otherwise, the freed space - * might fill up before the caller puts in their new value.) 
+ * Assumes that a lock is held by the caller to ensure only one thread is dropping blocks. + * Otherwise, the freed space may fill up before the caller puts in their new value. */ private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = { logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( @@ -172,36 +183,44 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) return false } - // TODO: This should relinquish the lock on the MemoryStore while flushing out old blocks - // in order to allow parallelism in writing to disk if (maxMemory - currentMemory < space) { val rddToAdd = getRddId(blockIdToAdd) val selectedBlocks = new ArrayBuffer[String]() var selectedMemory = 0L - val iterator = entries.entrySet().iterator() - while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { - val pair = iterator.next() - val blockId = pair.getKey - if (rddToAdd != null && rddToAdd == getRddId(blockId)) { - logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + - "block from the same RDD") - return false + // This is synchronized to ensure that the set of entries is not changed + // (because of getValue or getBytes) while traversing the iterator, as that + // can lead to exceptions. + entries.synchronized { + val iterator = entries.entrySet().iterator() + while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { + val pair = iterator.next() + val blockId = pair.getKey + if (rddToAdd != null && rddToAdd == getRddId(blockId)) { + logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + + "block from the same RDD") + return false + } + selectedBlocks += blockId + selectedMemory += pair.getValue.size } - selectedBlocks += blockId - selectedMemory += pair.getValue.size } if (maxMemory - (currentMemory - selectedMemory) >= space) { logInfo(selectedBlocks.size + " blocks selected for dropping") for (blockId <- selectedBlocks) { - val entry = entries.get(blockId) - val data = if (entry.deserialized) { - Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) - } else { - Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + val entry = entries.synchronized { entries.get(blockId) } + // This should never be null as only one thread should be dropping + // blocks and removing entries. However the check is still here for + // future safety. + if (entry != null) { + val data = if (entry.deserialized) { + Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) + } else { + Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + } + blockManager.dropFromMemory(blockId, data) } - blockManager.dropFromMemory(blockId, data) } return true } else { @@ -212,7 +231,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def contains(blockId: String): Boolean = { - synchronized { entries.containsKey(blockId) } + entries.synchronized { entries.containsKey(blockId) } } } -- cgit v1.2.3 From 04e9e9d93c512f856116bc2c99c35dfb48b4adee Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 11 Nov 2012 08:54:21 -0800 Subject: Refactored BlockManagerMaster (not BlockManagerMasterActor) to simplify the code and fix live lock problem in unlimited attempts to contact the master. Also added testcases in the BlockManagerSuite to test BlockManagerMaster methods getPeers and getLocations. 
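For orientation: the refactor below collapses the old mustRegisterBlockManager / mustHeartBeat / mustGetLocations / mustGetPeers retry loops, which could spin forever when the master never answered, into one generic ask[T] helper that makes a bounded number of attempts and then throws. Here is a generic sketch of that bounded-retry shape, independent of Akka; retry and its parameters are illustrative names, not Spark API.

object BoundedRetrySketch {
  // Run `op` up to maxAttempts times, sleeping briefly between failures, and surface
  // the last exception instead of retrying without bound (the source of the live lock).
  def retry[T](maxAttempts: Int, sleepMs: Long)(op: => T): T = {
    var lastException: Exception = null
    var attempts = 0
    while (attempts < maxAttempts) {
      attempts += 1
      try {
        val result = op
        if (result == null) {
          throw new Exception("operation returned null")
        }
        return result
      } catch {
        case ie: InterruptedException => throw ie   // never swallow interrupts
        case e: Exception =>
          lastException = e
          Thread.sleep(sleepMs)
      }
    }
    throw new Exception("Giving up after " + maxAttempts + " attempts", lastException)
  }
}

Wrapped around an actor ask, for example retry(5, 100) { Await.result(masterActor.ask(msg)(timeout), timeout) }, this bounds how long an unresponsive master can stall the caller.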
--- .../main/scala/spark/storage/BlockManager.scala | 14 +- .../scala/spark/storage/BlockManagerMaster.scala | 281 +++++++-------------- .../scala/spark/storage/BlockManagerSuite.scala | 30 ++- 3 files changed, 127 insertions(+), 198 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 8c7b1417be..70d6d8369d 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -120,8 +120,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * BlockManagerWorker actor. */ private def initialize() { - master.mustRegisterBlockManager( - RegisterBlockManager(blockManagerId, maxMemory)) + master.registerBlockManager(blockManagerId, maxMemory) BlockManagerWorker.startBlockManagerWorker(this) } @@ -158,7 +157,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } } - master.mustHeartBeat(HeartBeat(blockManagerId, blockId, curLevel, inMemSize, onDiskSize)) + master.updateBlockInfo(blockManagerId, blockId, curLevel, inMemSize, onDiskSize) logDebug("Told master about block " + blockId) } @@ -167,7 +166,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def getLocations(blockId: String): Seq[String] = { val startTimeMs = System.currentTimeMillis - var managers = master.mustGetLocations(GetLocations(blockId)) + var managers = master.getLocations(blockId) val locations = managers.map(_.ip) logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs)) return locations @@ -178,8 +177,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def getLocations(blockIds: Array[String]): Array[Seq[String]] = { val startTimeMs = System.currentTimeMillis - val locations = master.mustGetLocationsMultipleBlockIds( - GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray + val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -343,7 +341,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } logDebug("Getting remote block " + blockId) // Get locations of block - val locations = master.mustGetLocations(GetLocations(blockId)) + val locations = master.getLocations(blockId) // Get block from remote locations for (loc <- locations) { @@ -721,7 +719,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val tLevel: StorageLevel = new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) if (cachedPeers == null) { - cachedPeers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + cachedPeers = master.getPeers(blockManagerId, level.replication - 1) } for (peer: BlockManagerId <- cachedPeers) { val start = System.nanoTime diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index b3345623b3..4d5ee8318c 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -26,7 +26,7 @@ case class RegisterBlockManager( extends ToBlockManagerMaster private[spark] -class HeartBeat( +class UpdateBlockInfo( var blockManagerId: BlockManagerId, var blockId: String, var storageLevel: StorageLevel, @@ -57,17 +57,17 @@ class HeartBeat( } 
private[spark] -object HeartBeat { +object UpdateBlockInfo { def apply(blockManagerId: BlockManagerId, blockId: String, storageLevel: StorageLevel, memSize: Long, - diskSize: Long): HeartBeat = { - new HeartBeat(blockManagerId, blockId, storageLevel, memSize, diskSize) + diskSize: Long): UpdateBlockInfo = { + new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize) } // For pattern-matching - def unapply(h: HeartBeat): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { + def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) } } @@ -182,8 +182,8 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor case RegisterBlockManager(blockManagerId, maxMemSize) => register(blockManagerId, maxMemSize) - case HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) => - heartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) + case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => + updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) case GetLocations(blockId) => getLocations(blockId) @@ -233,7 +233,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor sender ! true } - private def heartBeat( + private def updateBlockInfo( blockManagerId: BlockManagerId, blockId: String, storageLevel: StorageLevel, @@ -245,7 +245,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor if (blockId == null) { blockManagerInfo(blockManagerId).updateLastSeenMs() - logDebug("Got in heartBeat 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) + logDebug("Got in updateBlockInfo 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) sender ! 
true } @@ -350,211 +350,124 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: Boolean) extends Logging { - val AKKA_ACTOR_NAME: String = "BlockMasterManager" - val REQUEST_RETRY_INTERVAL_MS = 100 - val DEFAULT_MASTER_IP: String = System.getProperty("spark.master.host", "localhost") - val DEFAULT_MASTER_PORT: Int = System.getProperty("spark.master.port", "7077").toInt - val DEFAULT_MANAGER_IP: String = Utils.localHostName() - val DEFAULT_MANAGER_PORT: String = "10902" - + val actorName = "BlockMasterManager" val timeout = 10.seconds - var masterActor: ActorRef = null + val maxAttempts = 5 - if (isMaster) { - masterActor = actorSystem.actorOf( - Props(new BlockManagerMasterActor(isLocal)), name = AKKA_ACTOR_NAME) + var masterActor = if (isMaster) { + val actor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)), name = actorName) logInfo("Registered BlockManagerMaster Actor") + actor } else { - val url = "akka://spark@%s:%s/user/%s".format( - DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT, AKKA_ACTOR_NAME) + val host = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val url = "akka://spark@%s:%s/user/%s".format(host, port, actorName) + val actor = actorSystem.actorFor(url) logInfo("Connecting to BlockManagerMaster: " + url) - masterActor = actorSystem.actorFor(url) + actor } - def stop() { - if (masterActor != null) { - communicate(StopBlockManagerMaster) - masterActor = null - logInfo("BlockManagerMaster stopped") + /** + * Send a message to the master actor and get its result within a default timeout, or + * throw a SparkException if this fails. + */ + private def ask[T](message: Any): T = { + // TODO: Consider removing multiple attempts + if (masterActor == null) { + throw new SparkException("Error sending message to BlockManager as masterActor is null " + + "[message = " + message + "]") } - } - - // Send a message to the master actor and get its result within a default timeout, or - // throw a SparkException if this fails. - def askMaster(message: Any): Any = { - try { - val future = masterActor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with BlockManagerMaster", e) + var attempts = 0 + var lastException: Exception = null + while (attempts < maxAttempts) { + attempts += 1 + try { + val future = masterActor.ask(message)(timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new Exception("BlockManagerMaster returned null") + } + return result.asInstanceOf[T] + } catch { + case ie: InterruptedException => + throw ie + case e: Exception => + lastException = e + logWarning( + "Error sending message to BlockManagerMaster in " + attempts + " attempts", e) + } + Thread.sleep(100) } + throw new SparkException( + "Error sending message to BlockManagerMaster [message = " + message + "]", lastException) } - // Send a one-way message to the master actor, to which we expect it to reply with true. 
- def communicate(message: Any) { - if (askMaster(message) != true) { - throw new SparkException("Error reply received from BlockManagerMaster") + /** + * Send a one-way message to the master actor, to which we expect it to reply with true + */ + private def tell(message: Any) { + if (!ask[Boolean](message)) { + throw new SparkException("Telling master a message returned false") } } - def notifyADeadHost(host: String) { - communicate(RemoveHost(host + ":" + DEFAULT_MANAGER_PORT)) - logInfo("Removed " + host + " successfully in notifyADeadHost") - } - - def mustRegisterBlockManager(msg: RegisterBlockManager) { + /** + * Register the BlockManager's id with the master + */ + def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long) { logInfo("Trying to register BlockManager") - while (! syncRegisterBlockManager(msg)) { - logWarning("Failed to register " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - } - logInfo("Done registering BlockManager") - } - - def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = { - //val masterActor = RemoteActor.select(node, name) - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - communicate(msg) - logInfo("BlockManager registered successfully @ syncRegisterBlockManager") - logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return true - } catch { - case e: Exception => - logError("Failed in syncRegisterBlockManager", e) - return false - } + tell(RegisterBlockManager(blockManagerId, maxMemSize)) + logInfo("Registered BlockManager") } - def mustHeartBeat(msg: HeartBeat) { - while (! syncHeartBeat(msg)) { - logWarning("Failed to send heartbeat" + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - } - } - - def syncHeartBeat(msg: HeartBeat): Boolean = { - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncHeartBeat " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs)) - - try { - communicate(msg) - logDebug("Heartbeat sent successfully") - logDebug("Got in syncHeartBeat 1 " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs)) - return true - } catch { - case e: Exception => - logError("Failed in syncHeartBeat", e) - return false - } + def updateBlockInfo( + blockManagerId: BlockManagerId, + blockId: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long + ) { + tell(UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) + logInfo("Updated info of block " + blockId) } - def mustGetLocations(msg: GetLocations): Seq[BlockManagerId] = { - var res = syncGetLocations(msg) - while (res == null) { - logInfo("Failed to get locations " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetLocations(msg) - } - return res + /** Get locations of the blockId from the master */ + def getLocations(blockId: String): Seq[BlockManagerId] = { + ask[Seq[BlockManagerId]](GetLocations(blockId)) } - def syncGetLocations(msg: GetLocations): Seq[BlockManagerId] = { - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[ArrayBuffer[BlockManagerId]] - if (answer != null) { - logDebug("GetLocations successful") - logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in 
response to GetLocations") - return null - } - } catch { - case e: Exception => - logError("GetLocations failed", e) - return null - } + /** Get locations of multiple blockIds from the master */ + def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + ask[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } - def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): - Seq[Seq[BlockManagerId]] = { - var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg) - while (res == null) { - logWarning("Failed to GetLocationsMultipleBlockIds " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetLocationsMultipleBlockIds(msg) + /** Get ids of other nodes in the cluster from the master */ + def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { + val result = ask[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) + if (result.length != numPeers) { + throw new SparkException( + "Error getting peers, only got " + result.size + " instead of " + numPeers) } - return res + result } - def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): - Seq[Seq[BlockManagerId]] = { - val startTimeMs = System.currentTimeMillis - val tmp = " msg " + msg + " " - logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[Seq[Seq[BlockManagerId]]] - if (answer != null) { - logDebug("GetLocationsMultipleBlockIds successful") - logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp + - Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetLocationsMultipleBlockIds") - return null - } - } catch { - case e: Exception => - logError("GetLocationsMultipleBlockIds failed", e) - return null - } + /** Notify the master of a dead node */ + def notifyADeadHost(host: String) { + tell(RemoveHost(host + ":10902")) + logInfo("Told BlockManagerMaster to remove dead host " + host) } - def mustGetPeers(msg: GetPeers): Seq[BlockManagerId] = { - var res = syncGetPeers(msg) - while ((res == null) || (res.length != msg.size)) { - logInfo("Failed to get peers " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetPeers(msg) - } - - return res + /** Get the memory status form the master */ + def getMemoryStatus(): Map[BlockManagerId, (Long, Long)] = { + ask[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } - def syncGetPeers(msg: GetPeers): Seq[BlockManagerId] = { - val startTimeMs = System.currentTimeMillis - val tmp = " msg " + msg + " " - logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[Seq[BlockManagerId]] - if (answer != null) { - logDebug("GetPeers successful") - logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetPeers") - return null - } - } catch { - case e: Exception => - logError("GetPeers failed", e) - return null + /** Stop the master actor, called only on the Spark master node */ + def stop() { + if (masterActor != null) { + tell(StopBlockManagerMaster) + masterActor = null + logInfo("BlockManagerMaster stopped") } } - - def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askMaster(GetMemoryStatus).asInstanceOf[Map[BlockManagerId, (Long, Long)]] - } } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala 
b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index b9c19e61cd..0e78228134 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -20,9 +20,11 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT var oldOops: String = null // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test + System.setProperty("spark.kryoserializer.buffer.mb", "1") val serializer = new KryoSerializer before { + actorSystem = ActorSystem("test") master = new BlockManagerMaster(actorSystem, true, true) @@ -55,7 +57,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } } - test("manager-master interaction") { + test("master + 1 manager interaction") { store = new BlockManager(master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -72,17 +74,33 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.getSingle("a3") != None, "a3 was not in store") // Checking whether master knows about the blocks or not - assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1") - assert(master.mustGetLocations(GetLocations("a2")).size > 0, "master was not told about a2") - assert(master.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3") + assert(master.getLocations("a1").size === 1, "master was not told about a1") + assert(master.getLocations("a2").size === 1, "master was not told about a2") + assert(master.getLocations("a3").size === 0, "master was told about a3") // Drop a1 and a2 from memory; this should be reported back to the master store.dropFromMemory("a1", null) store.dropFromMemory("a2", null) assert(store.getSingle("a1") === None, "a1 not removed from store") assert(store.getSingle("a2") === None, "a2 not removed from store") - assert(master.mustGetLocations(GetLocations("a1")).size === 0, "master did not remove a1") - assert(master.mustGetLocations(GetLocations("a2")).size === 0, "master did not remove a2") + assert(master.getLocations("a1").size === 0, "master did not remove a1") + assert(master.getLocations("a2").size === 0, "master did not remove a2") + } + + test("master + 2 managers interaction") { + store = new BlockManager(master, serializer, 2000) + val otherStore = new BlockManager(master, new KryoSerializer, 2000) + + val peers = master.getPeers(store.blockManagerId, 1) + assert(peers.size === 1, "master did not return the other manager as a peer") + assert(peers.head === otherStore.blockManagerId, "peer returned by master is not the other manager") + + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_2) + otherStore.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2) + assert(master.getLocations("a1").size === 2, "master did not report 2 locations for a1") + assert(master.getLocations("a2").size === 2, "master did not report 2 locations for a2") } test("in-memory LRU storage") { -- cgit v1.2.3 From 4a1be7e0dbf0031d85b91dc1132fe101d87ba097 Mon Sep 17 00:00:00 2001 From: Denny Date: Mon, 12 Nov 2012 10:56:35 -0800 Subject: Refactor BlockManager UI and adding worker details. 
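The new worker and RDD summaries are built by rolling up the per-block statuses that each BlockManager reports to the master. As a rough sketch of that roll-up (simplified: BlockStatus here keeps only the two size fields, and block keys are assumed to look like "rdd_<rddId>_<splitIndex>"; the real logic is the StorageUtils helper added below):

    // Simplified stand-in for the real spark.storage.BlockStatus (sizes only).
    case class BlockStatus(memSize: Long, diskSize: Long)

    // Roll the per-block statuses reported to the master up into per-RDD totals.
    def summarizeRdds(blocks: Map[String, BlockStatus]): Map[Int, (Long, Long)] = {
      blocks
        .filter { case (key, _) => key.startsWith("rdd_") }                  // ignore non-RDD blocks
        .groupBy { case (key, _) => key.substring(0, key.lastIndexOf('_')) } // drop the split suffix
        .map { case (rddKey, rddBlocks) =>
          val memSize  = rddBlocks.values.map(_.memSize).sum
          val diskSize = rddBlocks.values.map(_.diskSize).sum
          rddKey.split("_").last.toInt -> (memSize, diskSize)                // e.g. "rdd_1" => 1
        }
    }

The same grouping, filtered to a single "rdd_<id>" prefix, yields the per-worker usage shown in the new worker table.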
--- core/src/main/scala/spark/RDD.scala | 7 +- core/src/main/scala/spark/SparkContext.scala | 2 - .../scala/spark/storage/BlockManagerMaster.scala | 11 ++- .../main/scala/spark/storage/BlockManagerUI.scala | 51 ++++---------- .../main/scala/spark/storage/StorageLevel.scala | 9 +++ .../main/scala/spark/storage/StorageUtils.scala | 78 ++++++++++++++++++++++ core/src/main/twirl/spark/storage/index.scala.html | 22 ++++-- core/src/main/twirl/spark/storage/rdd.scala.html | 35 ++++++---- .../main/twirl/spark/storage/rdd_row.scala.html | 18 ----- .../main/twirl/spark/storage/rdd_table.scala.html | 16 ++++- .../twirl/spark/storage/worker_table.scala.html | 24 +++++++ 11 files changed, 186 insertions(+), 87 deletions(-) create mode 100644 core/src/main/scala/spark/storage/StorageUtils.scala delete mode 100644 core/src/main/twirl/spark/storage/rdd_row.scala.html create mode 100644 core/src/main/twirl/spark/storage/worker_table.scala.html (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index dc757dc6aa..3669bda2d2 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -86,6 +86,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial @transient val dependencies: List[Dependency[_]] // Methods available on all RDDs: + + // A friendly name for this RDD + var name: String = null /** Record user function generating this RDD. */ private[spark] val origin = Utils.getSparkCallSite @@ -108,8 +111,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial private var storageLevel: StorageLevel = StorageLevel.NONE /* Assign a name to this RDD */ - def name(name: String) = { - sc.rddNames(this.id) = name + def setName(_name: String) = { + name = _name this } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 71c9dcd017..7ea0f6f9e0 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -113,8 +113,6 @@ class SparkContext( // Keeps track of all persisted RDDs private[spark] val persistentRdds = new ConcurrentHashMap[Int, RDD[_]]() - // A HashMap for friendly RDD Names - private[spark] val rddNames = new ConcurrentHashMap[Int, String]() // Add each JAR given through the constructor jars.foreach { addJar(_) } diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 3fc9b629c1..beafdda9d1 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -4,7 +4,7 @@ import java.io._ import java.util.{HashMap => JHashMap} import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.Random import akka.actor._ @@ -95,10 +95,7 @@ private[spark] case class GetStorageStatus extends ToBlockManagerMaster private[spark] -case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) - -private[spark] -case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, remainingMem: Long, blocks: Map[String, BlockStatus]) +case class BlockStatus(blockManagerId: BlockManagerId, storageLevel: StorageLevel, memSize: Long, diskSize: Long) private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { @@ -135,7 +132,7 @@ 
private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor if (storageLevel.isValid) { // isValid means it is either stored in-memory or on-disk. - _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) + _blocks.put(blockId, BlockStatus(blockManagerId, storageLevel, memSize, diskSize)) if (storageLevel.useMemory) { _remainingMem -= memSize logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( @@ -237,7 +234,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor private def getStorageStatus() { val res = blockManagerInfo.map { case(blockManagerId, info) => - StorageStatus(blockManagerId, info.maxMem, info.remainingMem, info.blocks.asScala) + StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap) } sender ! res } diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 635c096c87..35cbd59280 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -12,6 +12,7 @@ import scala.collection.mutable.ArrayBuffer import spark.{Logging, SparkContext, SparkEnv} import spark.util.AkkaUtils + private[spark] object BlockManagerUI extends Logging { @@ -32,9 +33,6 @@ object BlockManagerUI extends Logging { } -private[spark] -case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, - numPartitions: Int, memSize: Long, diskSize: Long) private[spark] class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, @@ -49,21 +47,17 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, // Request the current storage status from the Master val future = master ? GetStorageStatus future.map { status => - val storageStati = status.asInstanceOf[ArrayBuffer[StorageStatus]] + val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray // Calculate macro-level statistics - val maxMem = storageStati.map(_.maxMem).reduce(_+_) - val remainingMem = storageStati.map(_.remainingMem).reduce(_+_) - val diskSpaceUsed = storageStati.flatMap(_.blocks.values.map(_.diskSize)) + val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) + val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) + val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)) .reduceOption(_+_).getOrElse(0L) - // Filter out everything that's not and rdd. 
- val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => - k.startsWith("rdd") - }.toMap - val rdds = rddInfoFromBlockStati(rddBlocks) + val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) - spark.storage.html.index.render(maxMem, remainingMem, diskSpaceUsed, rdds.toList) + spark.storage.html.index.render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) } }}} ~ get { path("rdd") { parameter("id") { id => { completeWith { @@ -71,13 +65,13 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, future.map { status => val prefix = "rdd_" + id.toString - val storageStati = status.asInstanceOf[ArrayBuffer[StorageStatus]] - val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => - k.startsWith(prefix) - }.toMap - val rddInfo = rddInfoFromBlockStati(rddBlocks).first - spark.storage.html.rdd.render(rddInfo, rddBlocks) + val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray + val filteredStorageStatusList = StorageUtils.filterStorageStatusByPrefix(storageStatusList, prefix) + + val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).first + + spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) } }}}}} ~ @@ -87,23 +81,6 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, } - private def rddInfoFromBlockStati(infos: Map[String, BlockStatus]) : Array[RDDInfo] = { - infos.groupBy { case(k,v) => - // Group by rdd name, ignore the partition name - k.substring(0,k.lastIndexOf('_')) - }.map { case(k,v) => - val blockStati = v.map(_._2).toArray - // Add up memory and disk sizes - val tmp = blockStati.map { x => (x.memSize, x.diskSize)}.reduce { (x,y) => - (x._1 + y._1, x._2 + y._2) - } - // Get the friendly name for the rdd, if available. - // This is pretty hacky, is there a better way? - val rddId = k.split("_").last.toInt - val rddName : String = Option(sc.rddNames.get(rddId)).getOrElse(k) - val rddStorageLevel = sc.persistentRdds.get(rddId).getStorageLevel - RDDInfo(rddId, rddName, rddStorageLevel, blockStati.length, tmp._1, tmp._2) - }.toArray - } + } diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index c497f03e0c..97d8c7566d 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -68,6 +68,15 @@ class StorageLevel( override def toString: String = "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) + + def description : String = { + var result = "" + result += (if (useDisk) "Disk " else "") + result += (if (useMemory) "Memory " else "") + result += (if (deserialized) "Deserialized " else "Serialized") + result += "%sx Replicated".format(replication) + result + } } object StorageLevel { diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala new file mode 100644 index 0000000000..ebc7390ee5 --- /dev/null +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -0,0 +1,78 @@ +package spark.storage + +import spark.SparkContext + +private[spark] +case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, + blocks: Map[String, BlockStatus]) { + + def memUsed(blockPrefix: String = "") = { + blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize). 
+ reduceOption(_+_).getOrElse(0l) + } + + def diskUsed(blockPrefix: String = "") = { + blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize). + reduceOption(_+_).getOrElse(0l) + } + + def memRemaining : Long = maxMem - memUsed() + +} + +case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, + numPartitions: Int, memSize: Long, diskSize: Long, locations: Array[BlockManagerId]) + + +/* Helper methods for storage-related objects */ +private[spark] +object StorageUtils { + + /* Given the current storage status of the BlockManager, returns information for each RDD */ + def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus], + sc: SparkContext) : Array[RDDInfo] = { + rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc) + } + + /* Given a list of BlockStatus objets, returns information for each RDD */ + def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus], + sc: SparkContext) : Array[RDDInfo] = { + // Find all RDD Blocks (ignore broadcast variables) + val rddBlocks = infos.filterKeys(_.startsWith("rdd")) + + // Group by rddId, ignore the partition name + val groupedRddBlocks = infos.groupBy { case(k, v) => + k.substring(0,k.lastIndexOf('_')) + }.mapValues(_.values.toArray) + + // For each RDD, generate an RDDInfo object + groupedRddBlocks.map { case(rddKey, rddBlocks) => + + // Add up memory and disk sizes + val memSize = rddBlocks.map(_.memSize).reduce(_ + _) + val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _) + + // Find the id of the RDD, e.g. rdd_1 => 1 + val rddId = rddKey.split("_").last.toInt + // Get the friendly name for the rdd, if available. + val rddName = Option(sc.persistentRdds.get(rddId).name).getOrElse(rddKey) + val rddStorageLevel = sc.persistentRdds.get(rddId).getStorageLevel + + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize, + rddBlocks.map(_.blockManagerId)) + }.toArray + } + + /* Removes all BlockStatus object that are not part of a block prefix */ + def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus], + prefix: String) : Array[StorageStatus] = { + + storageStatusList.map { status => + val newBlocks = status.blocks.filterKeys(_.startsWith(prefix)) + //val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _) + StorageStatus(status.blockManagerId, status.maxMem, newBlocks) + } + + } + +} \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/index.scala.html b/core/src/main/twirl/spark/storage/index.scala.html index fa7dad51ee..2b337f6133 100644 --- a/core/src/main/twirl/spark/storage/index.scala.html +++ b/core/src/main/twirl/spark/storage/index.scala.html @@ -1,4 +1,5 @@ -@(maxMem: Long, remainingMem: Long, diskSpaceUsed: Long, rdds: List[spark.storage.RDDInfo]) +@(maxMem: Long, remainingMem: Long, diskSpaceUsed: Long, rdds: Array[spark.storage.RDDInfo], storageStatusList: Array[spark.storage.StorageStatus]) +@import spark.Utils @spark.common.html.layout(title = "Storage Dashboard") { @@ -7,16 +8,16 @@
  • Memory: - @{spark.Utils.memoryBytesToString(maxMem - remainingMem)} Used - (@{spark.Utils.memoryBytesToString(remainingMem)} Available)
  • -
  • Disk: @{spark.Utils.memoryBytesToString(diskSpaceUsed)} Used
  • + @{Utils.memoryBytesToString(maxMem - remainingMem)} Used + (@{Utils.memoryBytesToString(remainingMem)} Available) +
  • Disk: @{Utils.memoryBytesToString(diskSpaceUsed)} Used

- +

RDD Summary

@@ -25,4 +26,15 @@
+
+ + +
+
+

Worker Summary

+
+ @worker_table(storageStatusList) +
+
+ } \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd.scala.html b/core/src/main/twirl/spark/storage/rdd.scala.html index 075289c826..ac7f8c981f 100644 --- a/core/src/main/twirl/spark/storage/rdd.scala.html +++ b/core/src/main/twirl/spark/storage/rdd.scala.html @@ -1,4 +1,5 @@ -@(rddInfo: spark.storage.RDDInfo, blocks: Map[String, spark.storage.BlockStatus]) +@(rddInfo: spark.storage.RDDInfo, storageStatusList: Array[spark.storage.StorageStatus]) +@import spark.Utils @spark.common.html.layout(title = "RDD Info ") { @@ -8,21 +9,18 @@
  • Storage Level: - @(if (rddInfo.storageLevel.useDisk) "Disk" else "") - @(if (rddInfo.storageLevel.useMemory) "Memory" else "") - @(if (rddInfo.storageLevel.deserialized) "Deserialized" else "") - @(rddInfo.storageLevel.replication)x Replicated + @(rddInfo.storageLevel.description)
  • Partitions: @(rddInfo.numPartitions)
  • Memory Size: - @{spark.Utils.memoryBytesToString(rddInfo.memSize)} + @{Utils.memoryBytesToString(rddInfo.memSize)}
  • Disk Size: - @{spark.Utils.memoryBytesToString(rddInfo.diskSize)} + @{Utils.memoryBytesToString(rddInfo.diskSize)}
@@ -36,6 +34,7 @@

RDD Summary


+ @@ -47,17 +46,14 @@ - @blocks.map { case (k,v) => + @storageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1).map { case (k,v) => - - + + } @@ -67,4 +63,15 @@ +
+ + +
+
+

Worker Summary

+
+ @worker_table(storageStatusList, "rdd_" + rddInfo.id ) +
+
+ } \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd_row.scala.html b/core/src/main/twirl/spark/storage/rdd_row.scala.html deleted file mode 100644 index 3dd9944e3b..0000000000 --- a/core/src/main/twirl/spark/storage/rdd_row.scala.html +++ /dev/null @@ -1,18 +0,0 @@ -@(rdd: spark.storage.RDDInfo) - - - - - - - - \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd_table.scala.html b/core/src/main/twirl/spark/storage/rdd_table.scala.html index 24f55ccefb..af801cf229 100644 --- a/core/src/main/twirl/spark/storage/rdd_table.scala.html +++ b/core/src/main/twirl/spark/storage/rdd_table.scala.html @@ -1,4 +1,5 @@ -@(rdds: List[spark.storage.RDDInfo]) +@(rdds: Array[spark.storage.RDDInfo]) +@import spark.Utils
@k - @(if (v.storageLevel.useDisk) "Disk" else "") - @(if (v.storageLevel.useMemory) "Memory" else "") - @(if (v.storageLevel.deserialized) "Deserialized" else "") - @(v.storageLevel.replication)x Replicated + @(v.storageLevel.description) @{spark.Utils.memoryBytesToString(v.memSize)}@{spark.Utils.memoryBytesToString(v.diskSize)}@{Utils.memoryBytesToString(v.memSize)}@{Utils.memoryBytesToString(v.diskSize)}
- - @rdd.name - - - @(if (rdd.storageLevel.useDisk) "Disk" else "") - @(if (rdd.storageLevel.useMemory) "Memory" else "") - @(if (rdd.storageLevel.deserialized) "Deserialized" else "") - @(rdd.storageLevel.replication)x Replicated - @rdd.numPartitions@{spark.Utils.memoryBytesToString(rdd.memSize)}@{spark.Utils.memoryBytesToString(rdd.diskSize)}
@@ -12,7 +13,18 @@ @for(rdd <- rdds) { - @rdd_row(rdd) + + + + + + + }
+ + @rdd.name + + @(rdd.storageLevel.description) + @rdd.numPartitions@{Utils.memoryBytesToString(rdd.memSize)}@{Utils.memoryBytesToString(rdd.diskSize)}
\ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/worker_table.scala.html b/core/src/main/twirl/spark/storage/worker_table.scala.html new file mode 100644 index 0000000000..d54b8de4cc --- /dev/null +++ b/core/src/main/twirl/spark/storage/worker_table.scala.html @@ -0,0 +1,24 @@ +@(workersStatusList: Array[spark.storage.StorageStatus], prefix: String = "") +@import spark.Utils + + + + + + + + + + + @for(status <- workersStatusList) { + + + + + + } + +
Host | Memory Usage | Disk Usage
@(status.blockManagerId.ip + ":" + status.blockManagerId.port) + @(Utils.memoryBytesToString(status.memUsed(prefix))) + (@(Utils.memoryBytesToString(status.memRemaining)) Total Available) + @(Utils.memoryBytesToString(status.diskUsed(prefix)))
\ No newline at end of file -- cgit v1.2.3 From 8a25d530edfa3abcdbe2effcd6bfbe484ac40acb Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 13 Nov 2012 02:16:28 -0800 Subject: Optimized checkpoint writing by reusing FileSystem object. Fixed bug in updating of checkpoint data in DStream where the checkpointed RDDs, upon recovery, were not recognized as checkpointed RDDs and therefore deleted from HDFS. Made InputStreamsSuite more robust to timing delays. --- core/src/main/scala/spark/RDD.scala | 6 +- .../main/scala/spark/streaming/Checkpoint.scala | 73 ++++++++++-------- .../src/main/scala/spark/streaming/DStream.scala | 8 +- .../src/main/scala/spark/streaming/Scheduler.scala | 28 +++++-- .../scala/spark/streaming/StreamingContext.scala | 10 +-- streaming/src/test/resources/log4j.properties | 2 +- .../scala/spark/streaming/CheckpointSuite.scala | 25 +++--- .../scala/spark/streaming/InputStreamsSuite.scala | 88 ++++++++++++---------- 8 files changed, 129 insertions(+), 111 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 63048d5df0..6af8c377b5 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -189,11 +189,7 @@ abstract class RDD[T: ClassManifest]( def getCheckpointData(): Any = { synchronized { - if (isCheckpointed) { - checkpointFile - } else { - null - } + checkpointFile } } diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index a70fb8f73a..770f7b0cc0 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -5,7 +5,7 @@ import spark.{Logging, Utils} import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration -import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream} +import java.io._ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) @@ -18,8 +18,6 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val checkpointDir = ssc.checkpointDir val checkpointInterval = ssc.checkpointInterval - validate() - def validate() { assert(master != null, "Checkpoint.master is null") assert(framework != null, "Checkpoint.framework is null") @@ -27,35 +25,50 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) assert(checkpointTime != null, "Checkpoint.checkpointTime is null") logInfo("Checkpoint for time " + checkpointTime + " validated") } +} - def save(path: String) { - val file = new Path(path, "graph") - val conf = new Configuration() - val fs = file.getFileSystem(conf) - logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'") - if (fs.exists(file)) { - val bkFile = new Path(file.getParent, file.getName + ".bk") - FileUtil.copy(fs, file, fs, bkFile, true, true, conf) - logDebug("Moved existing checkpoint file to " + bkFile) +/** + * Convenience class to speed up the writing of graph checkpoint to file + */ +class CheckpointWriter(checkpointDir: String) extends Logging { + val file = new Path(checkpointDir, "graph") + val conf = new Configuration() + var fs = file.getFileSystem(conf) + val maxAttempts = 3 + + def write(checkpoint: Checkpoint) { + // TODO: maybe do this in a different thread from the main stream execution thread + var attempts = 0 + while (attempts < maxAttempts) { + attempts += 1 + try { + logDebug("Saving checkpoint for time " + 
checkpoint.checkpointTime + " to file '" + file + "'") + if (fs.exists(file)) { + val bkFile = new Path(file.getParent, file.getName + ".bk") + FileUtil.copy(fs, file, fs, bkFile, true, true, conf) + logDebug("Moved existing checkpoint file to " + bkFile) + } + val fos = fs.create(file) + val oos = new ObjectOutputStream(fos) + oos.writeObject(checkpoint) + oos.close() + logInfo("Checkpoint for time " + checkpoint.checkpointTime + " saved to file '" + file + "'") + fos.close() + return + } catch { + case ioe: IOException => + logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe) + } } - val fos = fs.create(file) - val oos = new ObjectOutputStream(fos) - oos.writeObject(this) - oos.close() - fs.close() - logInfo("Checkpoint of streaming context for time " + checkpointTime + " saved successfully to file '" + file + "'") - } - - def toBytes(): Array[Byte] = { - val bytes = Utils.serialize(this) - bytes + logError("Could not write checkpoint for time " + checkpoint.checkpointTime + " to file '" + file + "'") } } -object Checkpoint extends Logging { - def load(path: String): Checkpoint = { +object CheckpointReader extends Logging { + + def read(path: String): Checkpoint = { val fs = new Path(path).getFileSystem(new Configuration()) val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk")) @@ -82,17 +95,11 @@ object Checkpoint extends Logging { logError("Error loading checkpoint from file '" + file + "'", e) } } else { - logWarning("Could not load checkpoint from file '" + file + "' as it does not exist") + logWarning("Could not read checkpoint from file '" + file + "' as it does not exist") } }) - throw new Exception("Could not load checkpoint from path '" + path + "'") - } - - def fromBytes(bytes: Array[Byte]): Checkpoint = { - val cp = Utils.deserialize[Checkpoint](bytes) - cp.validate() - cp + throw new Exception("Could not read checkpoint from path '" + path + "'") } } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index abf132e45e..7e6f73dd7d 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -289,6 +289,7 @@ extends Serializable with Logging { */ protected[streaming] def updateCheckpointData(currentTime: Time) { logInfo("Updating checkpoint data for time " + currentTime) + // Get the checkpointed RDDs from the generated RDDs val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointData() != null) .map(x => (x._1, x._2.getCheckpointData())) @@ -334,8 +335,11 @@ extends Serializable with Logging { logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs") checkpointData.foreach { case(time, data) => { - logInfo("Restoring checkpointed RDD for time " + time + " from file") - generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString))) + logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'") + val rdd = ssc.sc.objectFile[T](data.toString) + // Set the checkpoint file name to identify this RDD as a checkpointed RDD by updateCheckpointData() + rdd.checkpointFile = data.toString + generatedRDDs += ((time, rdd)) } } dependencies.foreach(_.restoreCheckpointData()) diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index de0fb1f3ad..e2dca91179 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala 
+++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -16,8 +16,16 @@ extends Logging { initLogging() val graph = ssc.graph + val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt val jobManager = new JobManager(ssc, concurrentJobs) + + val checkpointWriter = if (ssc.checkpointInterval != null && ssc.checkpointDir != null) { + new CheckpointWriter(ssc.checkpointDir) + } else { + null + } + val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock") val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] val timer = new RecurringTimer(clock, ssc.graph.batchDuration, generateRDDs(_)) @@ -52,19 +60,23 @@ extends Logging { logInfo("Scheduler stopped") } - def generateRDDs(time: Time) { + private def generateRDDs(time: Time) { SparkEnv.set(ssc.env) logInfo("\n-----------------------------------------------------\n") - graph.generateRDDs(time).foreach(submitJob) - logInfo("Generated RDDs for time " + time) + graph.generateRDDs(time).foreach(jobManager.runJob) graph.forgetOldRDDs(time) - if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { - ssc.doCheckpoint(time) - } + doCheckpoint(time) + logInfo("Generated RDDs for time " + time) } - def submitJob(job: Job) { - jobManager.runJob(job) + private def doCheckpoint(time: Time) { + if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { + val startTime = System.currentTimeMillis() + ssc.graph.updateCheckpointData(time) + checkpointWriter.write(new Checkpoint(ssc, time)) + val stopTime = System.currentTimeMillis() + logInfo("Checkpointing the graph took " + (stopTime - startTime) + " ms") + } } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index ab6d6e8dea..ef6a05a392 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -28,7 +28,7 @@ final class StreamingContext ( def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) = this(new SparkContext(master, frameworkName, sparkHome, jars), null) - def this(path: String) = this(null, Checkpoint.load(path)) + def this(path: String) = this(null, CheckpointReader.read(path)) def this(cp_ : Checkpoint) = this(null, cp_) @@ -225,14 +225,6 @@ final class StreamingContext ( case e: Exception => logWarning("Error while stopping", e) } } - - def doCheckpoint(currentTime: Time) { - val startTime = System.currentTimeMillis() - graph.updateCheckpointData(currentTime) - new Checkpoint(this, currentTime).save(checkpointDir) - val stopTime = System.currentTimeMillis() - logInfo("Checkpointing the graph took " + (stopTime - startTime) + " ms") - } } diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 02fe16866e..33774b463d 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -1,5 +1,5 @@ # Set everything to be logged to the console -log4j.rootCategory=WARN, console +log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala 
b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 0ad57e38b9..b3afedf39f 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -24,7 +24,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { override def framework = "CheckpointSuite" - override def batchDuration = Milliseconds(200) + override def batchDuration = Milliseconds(500) override def checkpointDir = "checkpoint" @@ -34,7 +34,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { test("basic stream+rdd recovery") { - assert(batchDuration === Milliseconds(200), "batchDuration for this test must be 1 second") + assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") assert(checkpointInterval === batchDuration, "checkpointInterval for this test much be same as batchDuration") System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") @@ -134,9 +134,9 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val operation = (st: DStream[String]) => { st.map(x => (x, 1)) .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) - .checkpoint(Seconds(2)) + .checkpoint(batchDuration * 2) } - testCheckpointedOperation(input, operation, output, 3) + testCheckpointedOperation(input, operation, output, 7) } test("updateStateByKey") { @@ -148,14 +148,18 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { } st.map(x => (x, 1)) .updateStateByKey[RichInt](updateFunc) - .checkpoint(Seconds(2)) + .checkpoint(batchDuration * 2) .map(t => (t._1, t._2.self)) } - testCheckpointedOperation(input, operation, output, 3) + testCheckpointedOperation(input, operation, output, 7) } - - + /** + * Tests a streaming operation under checkpointing, by restart the operation + * from checkpoint file and verifying whether the final output is correct. + * The output is assumed to have come from a reliable queue which an replay + * data as required. + */ def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V], @@ -170,8 +174,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val initialNumExpectedOutputs = initialNumBatches val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs - // Do half the computation (half the number of batches), create checkpoint file and quit - + // Do the computation for initial number of batches, create checkpoint file and quit ssc = setupStreams[U, V](input, operation) val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) @@ -193,8 +196,6 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { * Advances the manual clock on the streaming scheduler by given number of batches. * It also wait for the expected amount of time for each batch. 
*/ - - def runStreamsWithRealDelay(ssc: StreamingContext, numBatches: Long) { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.time) diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 0957748603..3e99440226 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -16,24 +16,36 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + val testPort = 9999 + var testServer: TestServer = null + var testDir: File = null + override def checkpointDir = "checkpoint" after { FileUtils.deleteDirectory(new File(checkpointDir)) + if (testServer != null) { + testServer.stop() + testServer = null + } + if (testDir != null && testDir.exists()) { + FileUtils.deleteDirectory(testDir) + testDir = null + } } test("network input stream") { // Start the server - val serverPort = 9999 - val server = new TestServer(9999) - server.start() + testServer = new TestServer(testPort) + testServer.start() // Set up the streaming context and input streams val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - val networkStream = ssc.networkTextStream("localhost", serverPort, StorageLevel.MEMORY_AND_DISK) + val networkStream = ssc.networkTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]] val outputStream = new TestOutputStream(networkStream, outputBuffer) + def output = outputBuffer.flatMap(x => x) ssc.registerOutputStream(outputStream) ssc.start() @@ -41,21 +53,15 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq(1, 2, 3, 4, 5) val expectedOutput = input.map(_.toString) + Thread.sleep(1000) for (i <- 0 until input.size) { - server.send(input(i).toString + "\n") + testServer.send(input(i).toString + "\n") Thread.sleep(500) clock.addToTime(batchDuration.milliseconds) } - val startTime = System.currentTimeMillis() - while (outputBuffer.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - logInfo("output.size = " + outputBuffer.size + ", expectedOutput.size = " + expectedOutput.size) - Thread.sleep(100) - } Thread.sleep(1000) - val timeTaken = System.currentTimeMillis() - startTime - assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") logInfo("Stopping server") - server.stop() + testServer.stop() logInfo("Stopping context") ssc.stop() @@ -69,24 +75,24 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("--------------------------------") - assert(outputBuffer.size === expectedOutput.size) - for (i <- 0 until outputBuffer.size) { - assert(outputBuffer(i).size === 1) - assert(outputBuffer(i).head === expectedOutput(i)) + // Verify whether all the elements received are as expected + // (whether the elements were received one in each interval is not verified) + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i) === expectedOutput(i)) } } test("network input stream with checkpoint") { // Start the server - val serverPort = 9999 - val server = new 
TestServer(9999) - server.start() + testServer = new TestServer(testPort) + testServer.start() // Set up the streaming context and input streams var ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) ssc.checkpoint(checkpointDir, checkpointInterval) - val networkStream = ssc.networkTextStream("localhost", serverPort, StorageLevel.MEMORY_AND_DISK) + val networkStream = ssc.networkTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK) var outputStream = new TestOutputStream(networkStream, new ArrayBuffer[Seq[String]]) ssc.registerOutputStream(outputStream) ssc.start() @@ -94,7 +100,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Feed data to the server to send to the network receiver var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] for (i <- Seq(1, 2, 3)) { - server.send(i.toString + "\n") + testServer.send(i.toString + "\n") Thread.sleep(100) clock.addToTime(batchDuration.milliseconds) } @@ -109,7 +115,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { ssc.start() clock = ssc.scheduler.clock.asInstanceOf[ManualClock] for (i <- Seq(4, 5, 6)) { - server.send(i.toString + "\n") + testServer.send(i.toString + "\n") Thread.sleep(100) clock.addToTime(batchDuration.milliseconds) } @@ -120,12 +126,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("file input stream") { + // Create a temporary directory - val dir = { + testDir = { var temp = File.createTempFile(".temp.", Random.nextInt().toString) temp.delete() temp.mkdirs() - temp.deleteOnExit() logInfo("Created temp dir " + temp) temp } @@ -133,10 +139,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Set up the streaming context and input streams val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - val filestream = ssc.textFileStream(dir.toString) + val filestream = ssc.textFileStream(testDir.toString) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] def output = outputBuffer.flatMap(x => x) - val outputStream = new TestOutputStream(filestream, outputBuffer) ssc.registerOutputStream(outputStream) ssc.start() @@ -147,16 +152,16 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val expectedOutput = input.map(_.toString) Thread.sleep(1000) for (i <- 0 until input.size) { - FileUtils.writeStringToFile(new File(dir, i.toString), input(i).toString + "\n") - Thread.sleep(100) + FileUtils.writeStringToFile(new File(testDir, i.toString), input(i).toString + "\n") + Thread.sleep(500) clock.addToTime(batchDuration.milliseconds) - Thread.sleep(100) + //Thread.sleep(100) } val startTime = System.currentTimeMillis() - while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - //println("output.size = " + output.size + ", expectedOutput.size = " + expectedOutput.size) + /*while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + logInfo("output.size = " + output.size + ", expectedOutput.size = " + expectedOutput.size) Thread.sleep(100) - } + }*/ Thread.sleep(1000) val timeTaken = System.currentTimeMillis() - startTime assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") @@ -165,14 +170,16 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether data received by Spark Streaming was as expected logInfo("--------------------------------") - 
logInfo("output.size = " + output.size) + logInfo("output.size = " + outputBuffer.size) logInfo("output") - output.foreach(x => logInfo("[" + x.mkString(",") + "]")) + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("expected output.size = " + expectedOutput.size) logInfo("expected output") expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("--------------------------------") + // Verify whether all the elements received are as expected + // (whether the elements were received one in each interval is not verified) assert(output.size === expectedOutput.size) for (i <- 0 until output.size) { assert(output(i).size === 1) @@ -182,12 +189,11 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { test("file input stream with checkpoint") { // Create a temporary directory - val dir = { + testDir = { var temp = File.createTempFile(".temp.", Random.nextInt().toString) temp.delete() temp.mkdirs() - temp.deleteOnExit() - println("Created temp dir " + temp) + logInfo("Created temp dir " + temp) temp } @@ -195,7 +201,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { var ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) ssc.checkpoint(checkpointDir, checkpointInterval) - val filestream = ssc.textFileStream(dir.toString) + val filestream = ssc.textFileStream(testDir.toString) var outputStream = new TestOutputStream(filestream, new ArrayBuffer[Seq[String]]) ssc.registerOutputStream(outputStream) ssc.start() @@ -204,7 +210,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] Thread.sleep(1000) for (i <- Seq(1, 2, 3)) { - FileUtils.writeStringToFile(new File(dir, i.toString), i.toString + "\n") + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") Thread.sleep(100) clock.addToTime(batchDuration.milliseconds) } @@ -221,7 +227,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { clock = ssc.scheduler.clock.asInstanceOf[ManualClock] Thread.sleep(500) for (i <- Seq(4, 5, 6)) { - FileUtils.writeStringToFile(new File(dir, i.toString), i.toString + "\n") + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") Thread.sleep(100) clock.addToTime(batchDuration.milliseconds) } -- cgit v1.2.3 From 10c1abcb6ac42b248818fa585a9ad49c2fa4851a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 17 Nov 2012 17:27:00 -0800 Subject: Fixed checkpointing bug in CoGroupedRDD. CoGroupSplits kept around the RDD splits of its parent RDDs, thus checkpointing its parents did not release the references to the parent splits. 
--- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 18 ++++++++++---- core/src/test/scala/spark/CheckpointSuite.scala | 28 ++++++++++++++++++++++ .../src/main/scala/spark/streaming/DStream.scala | 4 ++-- .../main/scala/spark/streaming/DStreamGraph.scala | 2 +- 4 files changed, 45 insertions(+), 7 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index a313ebcbe8..94ef1b56e8 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -12,9 +12,20 @@ import spark.RDD import spark.ShuffleDependency import spark.SparkEnv import spark.Split +import java.io.{ObjectOutputStream, IOException} private[spark] sealed trait CoGroupSplitDep extends Serializable -private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep +private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], splitIndex: Int, var split: Split = null) + extends CoGroupSplitDep { + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + rdd.synchronized { + // Update the reference to parent split at the time of task serialization + split = rdd.splits(splitIndex) + oos.defaultWriteObject() + } + } +} private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep private[spark] @@ -55,7 +66,6 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) @transient var splits_ : Array[Split] = { - val firstRdd = rdds.head val array = new Array[Split](part.numPartitions) for (i <- 0 until array.size) { array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) => @@ -63,7 +73,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) case s: ShuffleDependency[_, _] => new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep case _ => - new NarrowCoGroupSplitDep(r, r.splits(i)): CoGroupSplitDep + new NarrowCoGroupSplitDep(r, i): CoGroupSplitDep } }.toList) } @@ -82,7 +92,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any])) } for ((dep, depNum) <- split.deps.zipWithIndex) dep match { - case NarrowCoGroupSplitDep(rdd, itsSplit) => { + case NarrowCoGroupSplitDep(rdd, itsSplitIndex, itsSplit) => { // Read them from the parent for ((k, v) <- rdd.iterator(itsSplit)) { getSeq(k.asInstanceOf[K])(depNum) += v diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 57dc43ddac..8622ce92aa 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -6,6 +6,7 @@ import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD} import spark.SparkContext._ import storage.StorageLevel import java.util.concurrent.Semaphore +import collection.mutable.ArrayBuffer class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { initLogging() @@ -92,6 +93,33 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val rdd2 = sc.makeRDD(5 to 6, 4).map(x => (x % 2, 1)) testCheckpointing(rdd1 => rdd1.map(x => (x % 2, 1)).cogroup(rdd2)) testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2)) + + // Special test to make sure that the CoGroupSplit of CoGroupedRDD do not + // hold on to the splits of its parent RDDs, as the splits of parent RDDs + // may change while checkpointing. 
Rather the splits of parent RDDs must + // be fetched at the time of serialization to ensure the latest splits to + // be sent along with the task. + + val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _) + + val ones = sc.parallelize(1 to 100, 1).map(x => (x,1)) + val reduced = ones.reduceByKey(_ + _) + val seqOfCogrouped = new ArrayBuffer[RDD[(Int, Int)]]() + seqOfCogrouped += reduced.cogroup(ones).mapValues[Int](add) + for(i <- 1 to 10) { + seqOfCogrouped += seqOfCogrouped.last.cogroup(ones).mapValues(add) + } + val finalCogrouped = seqOfCogrouped.last + val intermediateCogrouped = seqOfCogrouped(5) + + val bytesBeforeCheckpoint = Utils.serialize(finalCogrouped.splits) + intermediateCogrouped.checkpoint() + finalCogrouped.count() + sleep(intermediateCogrouped) + val bytesAfterCheckpoint = Utils.serialize(finalCogrouped.splits) + println("Before = " + bytesBeforeCheckpoint.size + ", after = " + bytesAfterCheckpoint.size) + assert(bytesAfterCheckpoint.size < bytesBeforeCheckpoint.size, + "CoGroupedSplits still holds on to the splits of its parent RDDs") } /** diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 76cdf8c464..13770aa8fd 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -226,11 +226,11 @@ extends Serializable with Logging { case Some(newRDD) => if (storageLevel != StorageLevel.NONE) { newRDD.persist(storageLevel) - logInfo("Persisting RDD for time " + time + " to " + storageLevel + " at time " + time) + logInfo("Persisting RDD " + newRDD.id + " for time " + time + " to " + storageLevel + " at time " + time) } if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { newRDD.checkpoint() - logInfo("Marking RDD " + newRDD + " for time " + time + " for checkpointing at time " + time) + logInfo("Marking RDD " + newRDD.id + " for time " + time + " for checkpointing at time " + time) } generatedRDDs.put(time, newRDD) Some(newRDD) diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index 246522838a..bd8c033eab 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -105,7 +105,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { private[streaming] def validate() { this.synchronized { assert(batchDuration != null, "Batch duration has not been set") - assert(batchDuration > Milliseconds(100), "Batch duration of " + batchDuration + " is very low") + //assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration + " is very low") assert(getOutputStreams().size > 0, "No output streams registered, so nothing to execute") } } -- cgit v1.2.3 From c97ebf64377e853ab7c616a103869a4417f25954 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 19 Nov 2012 23:22:07 +0000 Subject: Fixed bug in the number of splits in RDD after checkpointing. Modified reduceByKeyAndWindow (naive) computation from window+reduceByKey to reduceByKey+window+reduceByKey. 
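The reduceByKeyAndWindow change amounts to pre-aggregating every batch before it enters the window, so the windowed union holds at most one record per key per batch instead of every raw record. A sketch of the two formulations (illustrative only: windowedCounts is a hypothetical helper, and the usual StreamingContext implicits for pair DStreams are assumed):

    import spark.streaming._
    import spark.streaming.StreamingContext._

    // Both DStreams produce the same per-key counts over a 30s window sliding every
    // 10s, but the second reduces within each batch first and then reduces the
    // (much smaller) pre-aggregated window.
    def windowedCounts(pairs: DStream[(String, Int)]) = {
      val naive = pairs.window(Seconds(30), Seconds(10)).reduceByKey(_ + _)

      val preAggregated = pairs
        .reduceByKey(_ + _)
        .window(Seconds(30), Seconds(10))
        .reduceByKey(_ + _)

      (naive, preAggregated)
    }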
--- conf/streaming-env.sh.template | 2 +- core/src/main/scala/spark/RDD.scala | 3 ++- streaming/src/main/scala/spark/streaming/DStream.scala | 3 ++- streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala | 6 +++++- streaming/src/main/scala/spark/streaming/Scheduler.scala | 2 +- streaming/src/main/scala/spark/streaming/WindowedDStream.scala | 3 +++ 6 files changed, 14 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/conf/streaming-env.sh.template b/conf/streaming-env.sh.template index 6b4094c515..1ea9ba5541 100755 --- a/conf/streaming-env.sh.template +++ b/conf/streaming-env.sh.template @@ -11,7 +11,7 @@ SPARK_JAVA_OPTS+=" -XX:+UseConcMarkSweepGC" -# Using of Kryo serialization can improve serialization performance +# Using Kryo serialization can improve serialization performance # and therefore the throughput of the Spark Streaming programs. However, # using Kryo serialization with custom classes may required you to # register the classes with Kryo. Refer to the Spark documentation diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 6af8c377b5..8af6c9bd6a 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -222,12 +222,13 @@ abstract class RDD[T: ClassManifest]( rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString rdd.saveAsObjectFile(checkpointFile) rdd.synchronized { - rdd.checkpointRDD = context.objectFile[T](checkpointFile) + rdd.checkpointRDD = context.objectFile[T](checkpointFile, rdd.splits.size) rdd.checkpointRDDSplits = rdd.checkpointRDD.splits rdd.changeDependencies(rdd.checkpointRDD) rdd.shouldCheckpoint = false rdd.isCheckpointInProgress = false rdd.isCheckpointed = true + println("Done checkpointing RDD " + rdd.id + ", " + rdd) } } } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 13770aa8fd..26d5ce9198 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -321,7 +321,8 @@ extends Serializable with Logging { } } } - logInfo("Updated checkpoint data for time " + currentTime) + logInfo("Updated checkpoint data for time " + currentTime + ", " + checkpointData.size + " checkpoints, " + + "[" + checkpointData.mkString(",") + "]") } /** diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index e09d27d34f..720e63bba0 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -4,6 +4,7 @@ import spark.streaming.StreamingContext._ import spark.{Manifests, RDD, Partitioner, HashPartitioner} import spark.SparkContext._ +import spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer @@ -115,7 +116,10 @@ extends Serializable { slideTime: Time, partitioner: Partitioner ): DStream[(K, V)] = { - self.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), partitioner) + val cleanedReduceFunc = ssc.sc.clean(reduceFunc) + self.reduceByKey(cleanedReduceFunc, partitioner) + .window(windowTime, slideTime) + .reduceByKey(cleanedReduceFunc, partitioner) } // This method is the efficient sliding window reduce operation, diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index e2dca91179..014021be61 100644 --- 
a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -17,7 +17,7 @@ extends Logging { val graph = ssc.graph - val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt + val concurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt val jobManager = new JobManager(ssc, concurrentJobs) val checkpointWriter = if (ssc.checkpointInterval != null && ssc.checkpointDir != null) { diff --git a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala index ce89a3f99b..e4d2a634f5 100644 --- a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala @@ -2,6 +2,7 @@ package spark.streaming import spark.RDD import spark.rdd.UnionRDD +import spark.storage.StorageLevel class WindowedDStream[T: ClassManifest]( @@ -18,6 +19,8 @@ class WindowedDStream[T: ClassManifest]( throw new Exception("The slide duration of WindowedDStream (" + _slideTime + ") " + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + parent.persist(StorageLevel.MEMORY_ONLY_SER) + def windowTime: Time = _windowTime override def dependencies = List(parent) -- cgit v1.2.3 From b18d70870a33a4783c6b3b787bef9b0eec30bce0 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 27 Nov 2012 15:08:49 -0800 Subject: Modified bunch HashMaps in Spark to use TimeStampedHashMap and made various modules use CleanupTask to periodically clean up metadata. --- core/src/main/scala/spark/CacheTracker.scala | 6 +- core/src/main/scala/spark/MapOutputTracker.scala | 27 ++++--- .../main/scala/spark/scheduler/DAGScheduler.scala | 13 +++- .../scala/spark/scheduler/ShuffleMapTask.scala | 6 +- core/src/main/scala/spark/util/CleanupTask.scala | 31 ++++++++ .../main/scala/spark/util/TimeStampedHashMap.scala | 87 ++++++++++++++++++++++ .../scala/spark/streaming/StreamingContext.scala | 13 +++- 7 files changed, 165 insertions(+), 18 deletions(-) create mode 100644 core/src/main/scala/spark/util/CleanupTask.scala create mode 100644 core/src/main/scala/spark/util/TimeStampedHashMap.scala (limited to 'core') diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index c5db6ce63a..0ee59bee0f 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -14,6 +14,7 @@ import scala.collection.mutable.HashSet import spark.storage.BlockManager import spark.storage.StorageLevel +import util.{CleanupTask, TimeStampedHashMap} private[spark] sealed trait CacheTrackerMessage @@ -30,7 +31,7 @@ private[spark] case object StopCacheTracker extends CacheTrackerMessage private[spark] class CacheTrackerActor extends Actor with Logging { // TODO: Should probably store (String, CacheType) tuples - private val locs = new HashMap[Int, Array[List[String]]] + private val locs = new TimeStampedHashMap[Int, Array[List[String]]] /** * A map from the slave's host name to its cache size. 
@@ -38,6 +39,8 @@ private[spark] class CacheTrackerActor extends Actor with Logging { private val slaveCapacity = new HashMap[String, Long] private val slaveUsage = new HashMap[String, Long] + private val cleanupTask = new CleanupTask("CacheTracker", locs.cleanup) + private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host) @@ -86,6 +89,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { case StopCacheTracker => logInfo("Stopping CacheTrackerActor") sender ! true + cleanupTask.cancel() context.stop(self) } } diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 45441aa5e5..d0be1bb913 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -17,6 +17,7 @@ import scala.collection.mutable.HashSet import scheduler.MapStatus import spark.storage.BlockManagerId import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import util.{CleanupTask, TimeStampedHashMap} private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String) @@ -43,7 +44,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea val timeout = 10.seconds - var mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]] + var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. @@ -52,7 +53,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea // Cache a serialized version of the output statuses for each shuffle to send them out faster var cacheGeneration = generation - val cachedSerializedStatuses = new HashMap[Int, Array[Byte]] + val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] var trackerActor: ActorRef = if (isMaster) { val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName) @@ -63,6 +64,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea actorSystem.actorFor(url) } + val cleanupTask = new CleanupTask("MapOutputTracker", this.cleanup) + // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. 
def askTracker(message: Any): Any = { @@ -83,14 +86,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.get(shuffleId) != null) { + if (mapStatuses.get(shuffleId) != None) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - var array = mapStatuses.get(shuffleId) + var array = mapStatuses(shuffleId) array.synchronized { array(mapId) = status } @@ -107,7 +110,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea } def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var array = mapStatuses.get(shuffleId) + var array = mapStatuses(shuffleId) if (array != null) { array.synchronized { if (array(mapId).address == bmAddress) { @@ -125,7 +128,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { - val statuses = mapStatuses.get(shuffleId) + val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") fetching.synchronized { @@ -138,7 +141,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea case e: InterruptedException => } } - return mapStatuses.get(shuffleId).map(status => + return mapStatuses(shuffleId).map(status => (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId)))) } else { fetching += shuffleId @@ -164,9 +167,15 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea } } + def cleanup(cleanupTime: Long) { + mapStatuses.cleanup(cleanupTime) + cachedSerializedStatuses.cleanup(cleanupTime) + } + def stop() { communicate(StopMapOutputTracker) mapStatuses.clear() + cleanupTask.cancel() trackerActor = null } @@ -192,7 +201,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea generationLock.synchronized { if (newGen > generation) { logInfo("Updating generation to " + newGen + " and clearing cache") - mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]] + mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] generation = newGen } } @@ -210,7 +219,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea case Some(bytes) => return bytes case None => - statuses = mapStatuses.get(shuffleId) + statuses = mapStatuses(shuffleId) generationGotten = generation } } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index aaaed59c4a..3af877b817 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -14,6 +14,7 @@ import spark.partial.ApproximateEvaluator import spark.partial.PartialResult import spark.storage.BlockManagerMaster import spark.storage.BlockManagerId +import util.{CleanupTask, TimeStampedHashMap} /** * A Scheduler subclass that implements stage-oriented scheduling. 
It computes a DAG of stages for @@ -61,9 +62,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val nextStageId = new AtomicInteger(0) - val idToStage = new HashMap[Int, Stage] + val idToStage = new TimeStampedHashMap[Int, Stage] - val shuffleToMapStage = new HashMap[Int, Stage] + val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] var cacheLocs = new HashMap[Int, Array[List[String]]] @@ -83,6 +84,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] + val cleanupTask = new CleanupTask("DAGScheduler", this.cleanup) + // Start a thread to run the DAGScheduler event loop new Thread("DAGScheduler") { setDaemon(true) @@ -591,8 +594,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with return Nil } + def cleanup(cleanupTime: Long) { + idToStage.cleanup(cleanupTime) + shuffleToMapStage.cleanup(cleanupTime) + } + def stop() { eventQueue.put(StopDAGScheduler) + cleanupTask.cancel() taskSched.stop() } } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 60105c42b6..fbf618c906 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -14,17 +14,19 @@ import com.ning.compress.lzf.LZFOutputStream import spark._ import spark.storage._ +import util.{TimeStampedHashMap, CleanupTask} private[spark] object ShuffleMapTask { // A simple map between the stage id to the serialized byte array of a task. // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. 
- val serializedInfoCache = new JHashMap[Int, Array[Byte]] + val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] + val cleanupTask = new CleanupTask("ShuffleMapTask", serializedInfoCache.cleanup) def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { - val old = serializedInfoCache.get(stageId) + val old = serializedInfoCache.get(stageId).orNull if (old != null) { return old } else { diff --git a/core/src/main/scala/spark/util/CleanupTask.scala b/core/src/main/scala/spark/util/CleanupTask.scala new file mode 100644 index 0000000000..ccc28803e0 --- /dev/null +++ b/core/src/main/scala/spark/util/CleanupTask.scala @@ -0,0 +1,31 @@ +package spark.util + +import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} +import java.util.{TimerTask, Timer} +import spark.Logging + +class CleanupTask(name: String, cleanupFunc: (Long) => Unit) extends Logging { + val delayMins = System.getProperty("spark.cleanup.delay", "-100").toInt + val periodMins = System.getProperty("spark.cleanup.period", (delayMins / 10).toString).toInt + val timer = new Timer(name + " cleanup timer", true) + val task = new TimerTask { + def run() { + try { + if (delayMins > 0) { + + cleanupFunc(System.currentTimeMillis() - (delayMins * 60 * 1000)) + logInfo("Ran cleanup task for " + name) + } + } catch { + case e: Exception => logError("Error running cleanup task for " + name, e) + } + } + } + if (periodMins > 0) { + timer.schedule(task, periodMins * 60 * 1000, periodMins * 60 * 1000) + } + + def cancel() { + timer.cancel() + } +} diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala new file mode 100644 index 0000000000..7a22b80a20 --- /dev/null +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -0,0 +1,87 @@ +package spark.util + +import scala.collection.JavaConversions._ +import scala.collection.mutable.{HashMap, Map} +import java.util.concurrent.ConcurrentHashMap + +/** + * This is a custom implementation of scala.collection.mutable.Map which stores the insertion + * time stamp along with each key-value pair. Key-value pairs that are older than a particular + * threshold time can them be removed using the cleanup method. This is intended to be a drop-in + * replacement of scala.collection.mutable.HashMap. 
+ */ +class TimeStampedHashMap[A, B] extends Map[A, B]() { + val internalMap = new ConcurrentHashMap[A, (B, Long)]() + + def get(key: A): Option[B] = { + val value = internalMap.get(key) + if (value != null) Some(value._1) else None + } + + def iterator: Iterator[(A, B)] = { + val jIterator = internalMap.entrySet().iterator() + jIterator.map(kv => (kv.getKey, kv.getValue._1)) + } + + override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { + val newMap = new TimeStampedHashMap[A, B1] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.put(kv._1, (kv._2, currentTime)) + newMap + } + + override def - (key: A): Map[A, B] = { + internalMap.remove(key) + this + } + + override def += (kv: (A, B)): this.type = { + internalMap.put(kv._1, (kv._2, currentTime)) + this + } + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def update(key: A, value: B) { + this += ((key, value)) + } + + override def apply(key: A): B = { + val value = internalMap.get(key) + if (value == null) throw new NoSuchElementException() + value._1 + } + + override def filter(p: ((A, B)) => Boolean): Map[A, B] = { + internalMap.map(kv => (kv._1, kv._2._1)).filter(p) + } + + override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() + + override def size(): Int = internalMap.size() + + override def foreach[U](f: ((A, B)) => U): Unit = { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val kv = (entry.getKey, entry.getValue._1) + f(kv) + } + } + + def cleanup(threshTime: Long) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue._2 < threshTime) { + iterator.remove() + } + } + } + + private def currentTime: Long = System.currentTimeMillis() + +} diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 4a41f2f516..58123dc82c 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -43,7 +43,7 @@ class StreamingContext private ( * @param batchDuration The time interval at which streaming data will be divided into batches */ def this(master: String, frameworkName: String, batchDuration: Time) = - this(new SparkContext(master, frameworkName), null, batchDuration) + this(StreamingContext.createNewSparkContext(master, frameworkName), null, batchDuration) /** * Recreates the StreamingContext from a checkpoint file. @@ -214,11 +214,8 @@ class StreamingContext private ( "Checkpoint directory has been set, but the graph checkpointing interval has " + "not been set. Please use StreamingContext.checkpoint() to set the interval." ) - - } - /** * This function starts the execution of the streams. 
*/ @@ -265,6 +262,14 @@ class StreamingContext private ( object StreamingContext { + + def createNewSparkContext(master: String, frameworkName: String): SparkContext = { + if (System.getProperty("spark.cleanup.delay", "-1").toInt < 0) { + System.setProperty("spark.cleanup.delay", "60") + } + new SparkContext(master, frameworkName) + } + implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = { new PairDStreamFunctions[K, V](stream) } -- cgit v1.2.3 From d5e7aad039603a8a02d11f9ebda001422ca4c341 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 28 Nov 2012 08:36:55 +0000 Subject: Bug fixes --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 11 ++++++++++- core/src/main/scala/spark/util/CleanupTask.scala | 17 +++++++++-------- .../main/scala/spark/streaming/StreamingContext.scala | 2 +- 3 files changed, 20 insertions(+), 10 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 3af877b817..affacb43ca 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -78,7 +78,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val running = new HashSet[Stage] // Stages we are running right now val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures - val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage + val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits val activeJobs = new HashSet[ActiveJob] @@ -595,8 +595,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } def cleanup(cleanupTime: Long) { + var sizeBefore = idToStage.size idToStage.cleanup(cleanupTime) + logInfo("idToStage " + sizeBefore + " --> " + idToStage.size) + + sizeBefore = shuffleToMapStage.size shuffleToMapStage.cleanup(cleanupTime) + logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size) + + sizeBefore = pendingTasks.size + pendingTasks.cleanup(cleanupTime) + logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size) } def stop() { diff --git a/core/src/main/scala/spark/util/CleanupTask.scala b/core/src/main/scala/spark/util/CleanupTask.scala index ccc28803e0..a4357c62c6 100644 --- a/core/src/main/scala/spark/util/CleanupTask.scala +++ b/core/src/main/scala/spark/util/CleanupTask.scala @@ -5,24 +5,25 @@ import java.util.{TimerTask, Timer} import spark.Logging class CleanupTask(name: String, cleanupFunc: (Long) => Unit) extends Logging { - val delayMins = System.getProperty("spark.cleanup.delay", "-100").toInt - val periodMins = System.getProperty("spark.cleanup.period", (delayMins / 10).toString).toInt + val delaySeconds = (System.getProperty("spark.cleanup.delay", "-100").toDouble * 60).toInt + val periodSeconds = math.max(10, delaySeconds / 10) val timer = new Timer(name + " cleanup timer", true) val task = new TimerTask { def run() { try { - if (delayMins > 0) { - - cleanupFunc(System.currentTimeMillis() - (delayMins * 60 * 1000)) + if (delaySeconds > 0) { + cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) logInfo("Ran cleanup task for " + name) - } + } } catch { case e: Exception => logError("Error 
running cleanup task for " + name, e) } } } - if (periodMins > 0) { - timer.schedule(task, periodMins * 60 * 1000, periodMins * 60 * 1000) + if (periodSeconds > 0) { + logInfo("Starting cleanup task for " + name + " with delay of " + delaySeconds + " seconds and " + + "period of " + periodSeconds + " secs") + timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) } def cancel() { diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 58123dc82c..90dd560752 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -264,7 +264,7 @@ class StreamingContext private ( object StreamingContext { def createNewSparkContext(master: String, frameworkName: String): SparkContext = { - if (System.getProperty("spark.cleanup.delay", "-1").toInt < 0) { + if (System.getProperty("spark.cleanup.delay", "-1").toDouble < 0) { System.setProperty("spark.cleanup.delay", "60") } new SparkContext(master, frameworkName) -- cgit v1.2.3 From e463ae492068d2922e1d50c051a87f8010953dff Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 28 Nov 2012 14:05:01 -0800 Subject: Modified StorageLevel and BlockManagerId to cache common objects and use cached object while deserializing. --- .../main/scala/spark/storage/BlockManager.scala | 28 +------------ .../main/scala/spark/storage/BlockManagerId.scala | 48 ++++++++++++++++++++++ .../main/scala/spark/storage/StorageLevel.scala | 28 ++++++++++++- .../scala/spark/storage/BlockManagerSuite.scala | 26 ++++++++++++ 4 files changed, 101 insertions(+), 29 deletions(-) create mode 100644 core/src/main/scala/spark/storage/BlockManagerId.scala (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 70d6d8369d..e4aa9247a3 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -20,33 +20,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import sun.nio.ch.DirectBuffer -private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { - def this() = this(null, 0) // For deserialization only - - def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) - - override def writeExternal(out: ObjectOutput) { - out.writeUTF(ip) - out.writeInt(port) - } - - override def readExternal(in: ObjectInput) { - ip = in.readUTF() - port = in.readInt() - } - - override def toString = "BlockManagerId(" + ip + ", " + port + ")" - - override def hashCode = ip.hashCode * 41 + port - - override def equals(that: Any) = that match { - case id: BlockManagerId => port == id.port && ip == id.ip - case _ => false - } -} - - -private[spark] +private[spark] case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message) diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala new file mode 100644 index 0000000000..4933cc6606 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -0,0 +1,48 @@ +package spark.storage + +import java.io.{IOException, ObjectOutput, ObjectInput, Externalizable} +import java.util.concurrent.ConcurrentHashMap + +private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { + def this() = this(null, 0) // For deserialization only + + 
def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) + + override def writeExternal(out: ObjectOutput) { + out.writeUTF(ip) + out.writeInt(port) + } + + override def readExternal(in: ObjectInput) { + ip = in.readUTF() + port = in.readInt() + } + + @throws(classOf[IOException]) + private def readResolve(): Object = { + BlockManagerId.getCachedBlockManagerId(this) + } + + + override def toString = "BlockManagerId(" + ip + ", " + port + ")" + + override def hashCode = ip.hashCode * 41 + port + + override def equals(that: Any) = that match { + case id: BlockManagerId => port == id.port && ip == id.ip + case _ => false + } +} + +object BlockManagerId { + val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() + + def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { + if (blockManagerIdCache.containsKey(id)) { + blockManagerIdCache.get(id) + } else { + blockManagerIdCache.put(id, id) + id + } + } +} \ No newline at end of file diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index c497f03e0c..eb88eb2759 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -1,6 +1,9 @@ package spark.storage -import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.io.{IOException, Externalizable, ObjectInput, ObjectOutput} +import collection.mutable +import util.Random +import collection.mutable.ArrayBuffer /** * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, @@ -17,7 +20,8 @@ class StorageLevel( extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. - + assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") + def this(flags: Int, replication: Int) { this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } @@ -27,6 +31,10 @@ class StorageLevel( override def clone(): StorageLevel = new StorageLevel( this.useDisk, this.useMemory, this.deserialized, this.replication) + override def hashCode(): Int = { + toInt * 41 + replication + } + override def equals(other: Any): Boolean = other match { case s: StorageLevel => s.useDisk == useDisk && @@ -66,6 +74,11 @@ class StorageLevel( replication = in.readByte() } + @throws(classOf[IOException]) + private def readResolve(): Object = { + StorageLevel.getCachedStorageLevel(this) + } + override def toString: String = "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) } @@ -82,4 +95,15 @@ object StorageLevel { val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2) val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) + + val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() + + def getCachedStorageLevel(level: StorageLevel): StorageLevel = { + if (storageLevelCache.containsKey(level)) { + storageLevelCache.get(level) + } else { + storageLevelCache.put(level, level) + level + } + } } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 0e78228134..a2d5e39859 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -57,6 +57,32 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } 
} + test("StorageLevel object caching") { + val level1 = new StorageLevel(false, false, false, 3) + val level2 = new StorageLevel(false, false, false, 3) + val bytes1 = spark.Utils.serialize(level1) + val level1_ = spark.Utils.deserialize[StorageLevel](bytes1) + val bytes2 = spark.Utils.serialize(level2) + val level2_ = spark.Utils.deserialize[StorageLevel](bytes2) + assert(level1_ === level1, "Deserialized level1 not same as original level1") + assert(level2_ === level2, "Deserialized level2 not same as original level1") + assert(level1_ === level2_, "Deserialized level1 not same as deserialized level2") + assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1") + } + + test("BlockManagerId object caching") { + val id1 = new StorageLevel(false, false, false, 3) + val id2 = new StorageLevel(false, false, false, 3) + val bytes1 = spark.Utils.serialize(id1) + val id1_ = spark.Utils.deserialize[StorageLevel](bytes1) + val bytes2 = spark.Utils.serialize(id2) + val id2_ = spark.Utils.deserialize[StorageLevel](bytes2) + assert(id1_ === id1, "Deserialized id1 not same as original id1") + assert(id2_ === id2, "Deserialized id2 not same as original id1") + assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2") + assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized level1") + } + test("master + 1 manager interaction") { store = new BlockManager(master, serializer, 2000) val a1 = new Array[Byte](400) -- cgit v1.2.3 From 9e9e9e1d898387a1996e4c57128bafadb5938a9b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 28 Nov 2012 18:48:14 -0800 Subject: Renamed CleanupTask to MetadataCleaner. --- core/src/main/scala/spark/CacheTracker.scala | 6 ++-- core/src/main/scala/spark/MapOutputTracker.scala | 6 ++-- .../main/scala/spark/scheduler/DAGScheduler.scala | 6 ++-- .../scala/spark/scheduler/ShuffleMapTask.scala | 5 ++-- core/src/main/scala/spark/util/CleanupTask.scala | 32 ---------------------- .../main/scala/spark/util/MetadataCleaner.scala | 32 ++++++++++++++++++++++ 6 files changed, 44 insertions(+), 43 deletions(-) delete mode 100644 core/src/main/scala/spark/util/CleanupTask.scala create mode 100644 core/src/main/scala/spark/util/MetadataCleaner.scala (limited to 'core') diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 0ee59bee0f..9888f061d9 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -14,7 +14,7 @@ import scala.collection.mutable.HashSet import spark.storage.BlockManager import spark.storage.StorageLevel -import util.{CleanupTask, TimeStampedHashMap} +import util.{MetadataCleaner, TimeStampedHashMap} private[spark] sealed trait CacheTrackerMessage @@ -39,7 +39,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { private val slaveCapacity = new HashMap[String, Long] private val slaveUsage = new HashMap[String, Long] - private val cleanupTask = new CleanupTask("CacheTracker", locs.cleanup) + private val metadataCleaner = new MetadataCleaner("CacheTracker", locs.cleanup) private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) @@ -89,7 +89,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { case StopCacheTracker => logInfo("Stopping CacheTrackerActor") sender ! 
true - cleanupTask.cancel() + metadataCleaner.cancel() context.stop(self) } } diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index d0be1bb913..20ff5431af 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -17,7 +17,7 @@ import scala.collection.mutable.HashSet import scheduler.MapStatus import spark.storage.BlockManagerId import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import util.{CleanupTask, TimeStampedHashMap} +import util.{MetadataCleaner, TimeStampedHashMap} private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String) @@ -64,7 +64,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea actorSystem.actorFor(url) } - val cleanupTask = new CleanupTask("MapOutputTracker", this.cleanup) + val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup) // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. @@ -175,7 +175,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea def stop() { communicate(StopMapOutputTracker) mapStatuses.clear() - cleanupTask.cancel() + metadataCleaner.cancel() trackerActor = null } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index affacb43ca..4b2570fa2b 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -14,7 +14,7 @@ import spark.partial.ApproximateEvaluator import spark.partial.PartialResult import spark.storage.BlockManagerMaster import spark.storage.BlockManagerId -import util.{CleanupTask, TimeStampedHashMap} +import util.{MetadataCleaner, TimeStampedHashMap} /** * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for @@ -84,7 +84,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] - val cleanupTask = new CleanupTask("DAGScheduler", this.cleanup) + val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) // Start a thread to run the DAGScheduler event loop new Thread("DAGScheduler") { @@ -610,7 +610,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def stop() { eventQueue.put(StopDAGScheduler) - cleanupTask.cancel() + metadataCleaner.cancel() taskSched.stop() } } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index fbf618c906..683f5ebec3 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -14,7 +14,7 @@ import com.ning.compress.lzf.LZFOutputStream import spark._ import spark.storage._ -import util.{TimeStampedHashMap, CleanupTask} +import util.{TimeStampedHashMap, MetadataCleaner} private[spark] object ShuffleMapTask { @@ -22,7 +22,8 @@ private[spark] object ShuffleMapTask { // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. 
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - val cleanupTask = new CleanupTask("ShuffleMapTask", serializedInfoCache.cleanup) + + val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.cleanup) def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { diff --git a/core/src/main/scala/spark/util/CleanupTask.scala b/core/src/main/scala/spark/util/CleanupTask.scala deleted file mode 100644 index a4357c62c6..0000000000 --- a/core/src/main/scala/spark/util/CleanupTask.scala +++ /dev/null @@ -1,32 +0,0 @@ -package spark.util - -import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} -import java.util.{TimerTask, Timer} -import spark.Logging - -class CleanupTask(name: String, cleanupFunc: (Long) => Unit) extends Logging { - val delaySeconds = (System.getProperty("spark.cleanup.delay", "-100").toDouble * 60).toInt - val periodSeconds = math.max(10, delaySeconds / 10) - val timer = new Timer(name + " cleanup timer", true) - val task = new TimerTask { - def run() { - try { - if (delaySeconds > 0) { - cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) - logInfo("Ran cleanup task for " + name) - } - } catch { - case e: Exception => logError("Error running cleanup task for " + name, e) - } - } - } - if (periodSeconds > 0) { - logInfo("Starting cleanup task for " + name + " with delay of " + delaySeconds + " seconds and " - + "period of " + periodSeconds + " secs") - timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) - } - - def cancel() { - timer.cancel() - } -} diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala new file mode 100644 index 0000000000..71ac39864e --- /dev/null +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -0,0 +1,32 @@ +package spark.util + +import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} +import java.util.{TimerTask, Timer} +import spark.Logging + +class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { + val delaySeconds = (System.getProperty("spark.cleanup.delay", "-100").toDouble * 60).toInt + val periodSeconds = math.max(10, delaySeconds / 10) + val timer = new Timer(name + " cleanup timer", true) + val task = new TimerTask { + def run() { + try { + if (delaySeconds > 0) { + cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) + logInfo("Ran metadata cleaner for " + name) + } + } catch { + case e: Exception => logError("Error running cleanup task for " + name, e) + } + } + } + if (periodSeconds > 0) { + logInfo("Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and " + + "period of " + periodSeconds + " secs") + timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) + } + + def cancel() { + timer.cancel() + } +} -- cgit v1.2.3 From c9789751bfc496d24e8369a0035d57f0ed8dcb58 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 28 Nov 2012 23:18:24 -0800 Subject: Added metadata cleaner to BlockManager to remove old blocks completely. 
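
The MetadataCleaner added above is a daemon Timer that periodically hands a threshold timestamp to a cleanup function, and TimeStampedHashMap supplies that function. The following is a rough wiring sketch based only on the sources shown in these patches; the enclosing object, key/value types, and property value are illustrative, and a later commit in this series renames the property to spark.cleaner.delay.

    import spark.util.{MetadataCleaner, TimeStampedHashMap}

    object CleanerWiringSketch {
      def main(args: Array[String]) {
        // Opt in to cleanup; at this point in the series the delay is read from
        // "spark.cleanup.delay" and interpreted in minutes.
        System.setProperty("spark.cleanup.delay", "60")

        val blockInfo = new TimeStampedHashMap[String, Int]()
        // Periodically drops entries whose insertion time is older than (now - delay).
        val cleaner = new MetadataCleaner("CleanerWiringSketch", blockInfo.cleanup)

        blockInfo("block-0") = 42        // timestamped on insert
        // ... long-running work; stale entries are removed by the timer thread ...
        cleaner.cancel()                 // stop the timer when the owning component shuts down
      }
    }
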
--- .../main/scala/spark/storage/BlockManager.scala | 47 ++++++++++++++++------ .../scala/spark/storage/BlockManagerMaster.scala | 1 + 2 files changed, 36 insertions(+), 12 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index e4aa9247a3..1e36578e1a 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -10,12 +10,12 @@ import java.nio.{MappedByteBuffer, ByteBuffer} import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} -import scala.collection.JavaConversions._ import spark.{CacheTracker, Logging, SizeEstimator, SparkException, Utils} import spark.network._ import spark.serializer.Serializer -import spark.util.ByteBufferInputStream +import spark.util.{MetadataCleaner, TimeStampedHashMap, ByteBufferInputStream} + import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import sun.nio.ch.DirectBuffer @@ -51,7 +51,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - private val blockInfo = new ConcurrentHashMap[String, BlockInfo](1000) + private val blockInfo = new TimeStampedHashMap[String, BlockInfo]() private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) private[storage] val diskStore: BlockStore = @@ -80,6 +80,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val host = System.getProperty("spark.hostname", Utils.localHostName()) + val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks) initialize() /** @@ -102,8 +103,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * Get storage level of local block. If no info exists for the block, then returns null. 
*/ def getLevel(blockId: String): StorageLevel = { - val info = blockInfo.get(blockId) - if (info != null) info.level else null + blockInfo.get(blockId).map(_.level).orNull } /** @@ -113,9 +113,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def reportBlockStatus(blockId: String) { val (curLevel, inMemSize, onDiskSize) = blockInfo.get(blockId) match { - case null => + case None => (StorageLevel.NONE, 0L, 0L) - case info => + case Some(info) => info.synchronized { info.level match { case null => @@ -173,7 +173,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { info.waitForReady() // In case the block is still being put() by another thread @@ -258,7 +258,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { info.waitForReady() // In case the block is still being put() by another thread @@ -517,7 +517,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new IllegalArgumentException("Storage level is null or invalid") } - val oldBlock = blockInfo.get(blockId) + val oldBlock = blockInfo.get(blockId).orNull if (oldBlock != null) { logWarning("Block " + blockId + " already exists on this machine; not re-adding it") oldBlock.waitForReady() @@ -618,7 +618,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new IllegalArgumentException("Storage level is null or invalid") } - if (blockInfo.containsKey(blockId)) { + if (blockInfo.contains(blockId)) { logWarning("Block " + blockId + " already exists on this machine; not re-adding it") return } @@ -740,7 +740,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { logInfo("Dropping block " + blockId + " from memory") - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { val level = info.level @@ -767,6 +767,29 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } + def dropOldBlocks(cleanupTime: Long) { + logInfo("Dropping blocks older than " + cleanupTime) + val iterator = blockInfo.internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) + if (time < cleanupTime) { + info.synchronized { + val level = info.level + if (level.useMemory) { + memoryStore.remove(id) + } + if (level.useDisk) { + diskStore.remove(id) + } + iterator.remove() + logInfo("Dropped block " + id) + } + reportBlockStatus(id) + } + } + } + def shouldCompress(blockId: String): Boolean = { if (blockId.startsWith("shuffle_")) { compressShuffle diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 397395a65b..af15663621 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -341,6 +341,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor throw new Exception("Self index for " + blockManagerId + " not found") } + // Note that this logic will 
select the same node multiple times if there aren't enough peers var index = selfIndex while (res.size < size) { index += 1 -- cgit v1.2.3 From 6fcd09f499dca66d255aa7196839156433aae442 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 29 Nov 2012 02:06:33 -0800 Subject: Added TimeStampedHashSet and used that to cleanup the list of registered RDD IDs in CacheTracker. --- core/src/main/scala/spark/CacheTracker.scala | 10 +++- .../main/scala/spark/util/TimeStampedHashMap.scala | 14 +++-- .../main/scala/spark/util/TimeStampedHashSet.scala | 66 ++++++++++++++++++++++ 3 files changed, 81 insertions(+), 9 deletions(-) create mode 100644 core/src/main/scala/spark/util/TimeStampedHashSet.scala (limited to 'core') diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 9888f061d9..cb54e12257 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -14,7 +14,7 @@ import scala.collection.mutable.HashSet import spark.storage.BlockManager import spark.storage.StorageLevel -import util.{MetadataCleaner, TimeStampedHashMap} +import util.{TimeStampedHashSet, MetadataCleaner, TimeStampedHashMap} private[spark] sealed trait CacheTrackerMessage @@ -39,7 +39,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { private val slaveCapacity = new HashMap[String, Long] private val slaveUsage = new HashMap[String, Long] - private val metadataCleaner = new MetadataCleaner("CacheTracker", locs.cleanup) + private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.cleanup) private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) @@ -113,11 +113,15 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b actorSystem.actorFor(url) } - val registeredRddIds = new HashSet[Int] + // TODO: Consider removing this HashSet completely as locs CacheTrackerActor already + // keeps track of registered RDDs + val registeredRddIds = new TimeStampedHashSet[Int] // Remembers which splits are currently being loaded (on worker nodes) val loading = new HashSet[String] + val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.cleanup) + // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. 
def askTracker(message: Any): Any = { diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index 7a22b80a20..9bcc9245c0 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -1,7 +1,7 @@ package spark.util -import scala.collection.JavaConversions._ -import scala.collection.mutable.{HashMap, Map} +import scala.collection.JavaConversions +import scala.collection.mutable.Map import java.util.concurrent.ConcurrentHashMap /** @@ -20,7 +20,7 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() { def iterator: Iterator[(A, B)] = { val jIterator = internalMap.entrySet().iterator() - jIterator.map(kv => (kv.getKey, kv.getValue._1)) + JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1)) } override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { @@ -31,8 +31,10 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() { } override def - (key: A): Map[A, B] = { - internalMap.remove(key) - this + val newMap = new TimeStampedHashMap[A, B] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.remove(key) + newMap } override def += (kv: (A, B)): this.type = { @@ -56,7 +58,7 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() { } override def filter(p: ((A, B)) => Boolean): Map[A, B] = { - internalMap.map(kv => (kv._1, kv._2._1)).filter(p) + JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) } override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() diff --git a/core/src/main/scala/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/spark/util/TimeStampedHashSet.scala new file mode 100644 index 0000000000..539dd75844 --- /dev/null +++ b/core/src/main/scala/spark/util/TimeStampedHashSet.scala @@ -0,0 +1,66 @@ +package spark.util + +import scala.collection.mutable.Set +import scala.collection.JavaConversions +import java.util.concurrent.ConcurrentHashMap + + +class TimeStampedHashSet[A] extends Set[A] { + val internalMap = new ConcurrentHashMap[A, Long]() + + def contains(key: A): Boolean = { + internalMap.contains(key) + } + + def iterator: Iterator[A] = { + val jIterator = internalMap.entrySet().iterator() + JavaConversions.asScalaIterator(jIterator).map(_.getKey) + } + + override def + (elem: A): Set[A] = { + val newSet = new TimeStampedHashSet[A] + newSet ++= this + newSet += elem + newSet + } + + override def - (elem: A): Set[A] = { + val newSet = new TimeStampedHashSet[A] + newSet ++= this + newSet -= elem + newSet + } + + override def += (key: A): this.type = { + internalMap.put(key, currentTime) + this + } + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def empty: Set[A] = new TimeStampedHashSet[A]() + + override def size(): Int = internalMap.size() + + override def foreach[U](f: (A) => U): Unit = { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + f(iterator.next.getKey) + } + } + + def cleanup(threshTime: Long) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue < threshTime) { + iterator.remove() + } + } + } + + private def currentTime: Long = System.currentTimeMillis() +} -- cgit v1.2.3 From 477de94894b7d8eeed281d33c12bcb2269d117c7 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 1 Dec 2012 13:15:06 -0800 Subject: Minor modifications. 
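
TimeStampedHashSet, added above for CacheTracker's registered RDD IDs, records an insertion timestamp per element so that cleanup(threshTime) can drop everything inserted before the threshold. A small standalone sketch of that behaviour follows; the object name and values are invented.

    import spark.util.TimeStampedHashSet

    object TimeStampedSetSketch {
      def main(args: Array[String]) {
        val registered = new TimeStampedHashSet[Int]
        registered += 1
        registered += 2
        Thread.sleep(10)
        val cutoff = System.currentTimeMillis()
        registered += 3                 // inserted at or after the cutoff, so it survives
        registered.cleanup(cutoff)      // removes elements inserted strictly before the cutoff
        assert(registered.toSet == Set(3))
        println(registered.size)
      }
    }
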
--- core/src/main/scala/spark/util/MetadataCleaner.scala | 7 ++++++- streaming/src/main/scala/spark/streaming/DStream.scala | 15 ++++++++++++++- .../scala/spark/streaming/ReducedWindowedDStream.scala | 4 ++-- .../src/main/scala/spark/streaming/StreamingContext.scala | 8 ++++++-- 4 files changed, 28 insertions(+), 6 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index 71ac39864e..2541b26255 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -5,7 +5,7 @@ import java.util.{TimerTask, Timer} import spark.Logging class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { - val delaySeconds = (System.getProperty("spark.cleanup.delay", "-100").toDouble * 60).toInt + val delaySeconds = MetadataCleaner.getDelaySeconds val periodSeconds = math.max(10, delaySeconds / 10) val timer = new Timer(name + " cleanup timer", true) val task = new TimerTask { @@ -30,3 +30,8 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging timer.cancel() } } + +object MetadataCleaner { + def getDelaySeconds = (System.getProperty("spark.cleaner.delay", "-100").toDouble * 60).toInt + def setDelaySeconds(delay: Long) { System.setProperty("spark.cleaner.delay", delay.toString) } +} diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 8efda2074d..28a3e2dfc7 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -146,6 +146,8 @@ extends Serializable with Logging { } protected[streaming] def validate() { + assert(rememberDuration != null, "Remember duration is set to null") + assert( !mustCheckpoint || checkpointInterval != null, "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " + @@ -180,13 +182,24 @@ extends Serializable with Logging { checkpointInterval + "). Please set it to higher than " + checkpointInterval + "." ) + val metadataCleanupDelay = System.getProperty("spark.cleanup.delay", "-1").toDouble + assert( + metadataCleanupDelay < 0 || rememberDuration < metadataCleanupDelay * 60 * 1000, + "It seems you are doing some DStream window operation or setting a checkpoint interval " + + "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + + "than " + rememberDuration.milliseconds + " milliseconds. But the Spark's metadata cleanup" + + "delay is set to " + metadataCleanupDelay + " minutes, which is not sufficient. Please set " + + "the Java property 'spark.cleanup.delay' to more than " + + math.ceil(rememberDuration.millis.toDouble / 60000.0).toInt + " minutes." 
+ ) + dependencies.foreach(_.validate()) logInfo("Slide time = " + slideTime) logInfo("Storage level = " + storageLevel) logInfo("Checkpoint interval = " + checkpointInterval) logInfo("Remember duration = " + rememberDuration) - logInfo("Initialized " + this) + logInfo("Initialized and validated " + this) } protected[streaming] def setContext(s: StreamingContext) { diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index bb852cbcca..f63a9e0011 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -118,8 +118,8 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( if (seqOfValues(0).isEmpty) { // If previous window's reduce value does not exist, then at least new values should exist if (newValues.isEmpty) { - val info = "seqOfValues =\n" + seqOfValues.map(x => "[" + x.mkString(",") + "]").mkString("\n") - throw new Exception("Neither previous window has value for key, nor new values found\n" + info) + throw new Exception("Neither previous window has value for key, nor new values found. " + + "Are you sure your key class hashes consistently?") } // Reduce the new values newValues.reduce(reduceF) // return diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 63d8766749..9c19f6588d 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -17,6 +17,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path import java.util.UUID +import spark.util.MetadataCleaner /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -268,8 +269,11 @@ class StreamingContext private ( object StreamingContext { def createNewSparkContext(master: String, frameworkName: String): SparkContext = { - if (System.getProperty("spark.cleanup.delay", "-1").toDouble < 0) { - System.setProperty("spark.cleanup.delay", "60") + + // Set the default cleaner delay to an hour if not already set. + // This should be sufficient for even 1 second interval. + if (MetadataCleaner.getDelaySeconds < 0) { + MetadataCleaner.setDelaySeconds(60) } new SparkContext(master, frameworkName) } -- cgit v1.2.3 From b4dba55f78b0dfda728cf69c9c17e4863010d28d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 2 Dec 2012 02:03:05 +0000 Subject: Made RDD checkpoint not create a new thread. Fixed bug in detecting when spark.cleaner.delay is insufficient. 
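
The validate() change above ties the metadata cleaner delay to how long a DStream must remember its RDDs. Below is a hedged sketch of the same sizing rule from the application side; the window length, default value, and object name are illustrative, and the property ("spark.cleaner.delay", in minutes) follows the MetadataCleaner version in this commit.

    object CleanerDelaySizingSketch {
      def main(args: Array[String]) {
        val longestWindowMs = 10 * 60 * 1000L   // longest window / remember duration used by the app
        val delayMinutes = System.getProperty("spark.cleaner.delay", "-1").toDouble
        val delayMs = delayMinutes * 60 * 1000

        // Mirrors the assertion in DStream.validate(): a non-negative cleaner delay must
        // exceed the remember duration, or RDD metadata may be cleaned while still needed.
        require(delayMinutes < 0 || longestWindowMs < delayMs,
          "set spark.cleaner.delay to more than " +
            math.ceil(longestWindowMs / 60000.0).toInt + " minutes")
      }
    }
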
--- core/src/main/scala/spark/RDD.scala | 31 +++++++--------------- .../main/scala/spark/util/TimeStampedHashMap.scala | 3 ++- .../src/main/scala/spark/streaming/DStream.scala | 9 ++++--- 3 files changed, 17 insertions(+), 26 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 8af6c9bd6a..fbfcfbd704 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -211,28 +211,17 @@ abstract class RDD[T: ClassManifest]( if (startCheckpoint) { val rdd = this - val env = SparkEnv.get - - // Spawn a new thread to do the checkpoint as it takes sometime to write the RDD to file - val th = new Thread() { - override def run() { - // Save the RDD to a file, create a new HadoopRDD from it, - // and change the dependencies from the original parents to the new RDD - SparkEnv.set(env) - rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString - rdd.saveAsObjectFile(checkpointFile) - rdd.synchronized { - rdd.checkpointRDD = context.objectFile[T](checkpointFile, rdd.splits.size) - rdd.checkpointRDDSplits = rdd.checkpointRDD.splits - rdd.changeDependencies(rdd.checkpointRDD) - rdd.shouldCheckpoint = false - rdd.isCheckpointInProgress = false - rdd.isCheckpointed = true - println("Done checkpointing RDD " + rdd.id + ", " + rdd) - } - } + rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString + rdd.saveAsObjectFile(checkpointFile) + rdd.synchronized { + rdd.checkpointRDD = context.objectFile[T](checkpointFile, rdd.splits.size) + rdd.checkpointRDDSplits = rdd.checkpointRDD.splits + rdd.changeDependencies(rdd.checkpointRDD) + rdd.shouldCheckpoint = false + rdd.isCheckpointInProgress = false + rdd.isCheckpointed = true + println("Done checkpointing RDD " + rdd.id + ", " + rdd + ", created RDD " + rdd.checkpointRDD.id + ", " + rdd.checkpointRDD) } - th.start() } else { // Recursively call doCheckpoint() to perform checkpointing on parent RDD if they are marked dependencies.foreach(_.rdd.doCheckpoint()) diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index 9bcc9245c0..52f03784db 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -10,7 +10,7 @@ import java.util.concurrent.ConcurrentHashMap * threshold time can them be removed using the cleanup method. This is intended to be a drop-in * replacement of scala.collection.mutable.HashMap. */ -class TimeStampedHashMap[A, B] extends Map[A, B]() { +class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { val internalMap = new ConcurrentHashMap[A, (B, Long)]() def get(key: A): Option[B] = { @@ -79,6 +79,7 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() { while(iterator.hasNext) { val entry = iterator.next() if (entry.getValue._2 < threshTime) { + logDebug("Removing key " + entry.getKey) iterator.remove() } } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 28a3e2dfc7..d2e9de110e 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -182,14 +182,15 @@ extends Serializable with Logging { checkpointInterval + "). Please set it to higher than " + checkpointInterval + "." 
) - val metadataCleanupDelay = System.getProperty("spark.cleanup.delay", "-1").toDouble + val metadataCleanerDelay = spark.util.MetadataCleaner.getDelaySeconds + logInfo("metadataCleanupDelay = " + metadataCleanerDelay) assert( - metadataCleanupDelay < 0 || rememberDuration < metadataCleanupDelay * 60 * 1000, + metadataCleanerDelay < 0 || rememberDuration < metadataCleanerDelay * 1000, "It seems you are doing some DStream window operation or setting a checkpoint interval " + "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + "than " + rememberDuration.milliseconds + " milliseconds. But the Spark's metadata cleanup" + - "delay is set to " + metadataCleanupDelay + " minutes, which is not sufficient. Please set " + - "the Java property 'spark.cleanup.delay' to more than " + + "delay is set to " + (metadataCleanerDelay / 60.0) + " minutes, which is not sufficient. Please set " + + "the Java property 'spark.cleaner.delay' to more than " + math.ceil(rememberDuration.millis.toDouble / 60000.0).toInt + " minutes." ) -- cgit v1.2.3 From a69a82be2682148f5d1ebbdede15a47c90eea73d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 3 Dec 2012 22:37:31 -0800 Subject: Added metadata cleaner to HttpBroadcast to clean up old broacast files. --- .../main/scala/spark/broadcast/HttpBroadcast.scala | 24 ++++++++++++++++++++++ 1 file changed, 24 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 7eb4ddb74f..fef264aab1 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -11,6 +11,7 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream import spark._ import spark.storage.StorageLevel +import util.{MetadataCleaner, TimeStampedHashSet} private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { @@ -64,6 +65,10 @@ private object HttpBroadcast extends Logging { private var serverUri: String = null private var server: HttpServer = null + private val files = new TimeStampedHashSet[String] + private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup) + + def initialize(isMaster: Boolean) { synchronized { if (!initialized) { @@ -85,6 +90,7 @@ private object HttpBroadcast extends Logging { server = null } initialized = false + cleaner.cancel() } } @@ -108,6 +114,7 @@ private object HttpBroadcast extends Logging { val serOut = ser.serializeStream(out) serOut.writeObject(value) serOut.close() + files += file.getAbsolutePath } def read[T](id: Long): T = { @@ -123,4 +130,21 @@ private object HttpBroadcast extends Logging { serIn.close() obj } + + def cleanup(cleanupTime: Long) { + val iterator = files.internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val (file, time) = (entry.getKey, entry.getValue) + if (time < cleanupTime) { + try { + iterator.remove() + new File(file.toString).delete() + logInfo("Deleted broadcast file '" + file + "'") + } catch { + case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) + } + } + } + } } -- cgit v1.2.3 From 21a08529768a5073bc5c15b6c2642ceef2acd0d5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 4 Dec 2012 22:10:25 -0800 Subject: Refactored RDD checkpointing to minimize extra fields in RDD class. 
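For orientation, a minimal usage sketch of the public checkpointing API that this refactoring settles on. checkpoint(), isCheckpointed() and getCheckpointFile() are taken from the diff below; setCheckpointDir() is only referred to in the RDD scaladoc, so its exact form, along with the paths and data, is assumed for illustration.

    // Sketch, not part of the patch: mark an RDD for checkpointing and observe the
    // effect once the first job has run (doCheckpoint() is called by the DAGScheduler
    // after the job completes).
    val sc = new SparkContext("local", "CheckpointExample")
    sc.setCheckpointDir("/tmp/spark-checkpoints")  // assumed setter; referenced in the scaladoc below
    val rdd = sc.makeRDD(1 to 100, 4).map(_ * 2)
    rdd.checkpoint()                    // only marks the RDD; nothing is written yet
    rdd.count()                         // first job materializes the RDD, then it is saved as an object file
    assert(rdd.isCheckpointed())        // lineage now truncated to the checkpointed RDD
    println(rdd.getCheckpointFile())    // Some(<checkpointDir>/rdd-<id>)

Keeping this state in RDDCheckpointData rather than on the RDD itself is what lets the serialized RDD shrink after checkpointing, which the updated CheckpointSuite below measures directly.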
--- core/src/main/scala/spark/RDD.scala | 149 ++++++++------------- core/src/main/scala/spark/RDDCheckpointData.scala | 68 ++++++++++ core/src/main/scala/spark/rdd/BlockRDD.scala | 9 +- core/src/main/scala/spark/rdd/CartesianRDD.scala | 10 +- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 2 +- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 2 +- core/src/main/scala/spark/rdd/HadoopRDD.scala | 2 - core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 2 - core/src/main/scala/spark/rdd/UnionRDD.scala | 12 +- .../main/scala/spark/scheduler/DAGScheduler.scala | 2 +- core/src/test/scala/spark/CheckpointSuite.scala | 73 +--------- .../src/main/scala/spark/streaming/DStream.scala | 7 +- 12 files changed, 144 insertions(+), 194 deletions(-) create mode 100644 core/src/main/scala/spark/RDDCheckpointData.scala (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index fbfcfbd704..e9bd131e61 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -99,13 +99,7 @@ abstract class RDD[T: ClassManifest]( val partitioner: Option[Partitioner] = None /** Optionally overridden by subclasses to specify placement preferences. */ - def preferredLocations(split: Split): Seq[String] = { - if (isCheckpointed) { - checkpointRDD.preferredLocations(split) - } else { - Nil - } - } + def preferredLocations(split: Split): Seq[String] = Nil /** The [[spark.SparkContext]] that this RDD was created on. */ def context = sc @@ -118,6 +112,8 @@ abstract class RDD[T: ClassManifest]( // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE + protected[spark] val checkpointData = new RDDCheckpointData(this) + /** Returns the first parent RDD */ protected[spark] def firstParent[U: ClassManifest] = { dependencies.head.rdd.asInstanceOf[RDD[U]] @@ -126,17 +122,6 @@ abstract class RDD[T: ClassManifest]( /** Returns the `i` th parent RDD */ protected[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] - // Variables relating to checkpointing - protected val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD - - protected var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing - protected var isCheckpointInProgress = false // set to true when checkpointing is in progress - protected[spark] var isCheckpointed = false // set to true after checkpointing is completed - - protected[spark] var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed - protected var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file - protected var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD - // Methods available on all RDDs: /** @@ -162,83 +147,14 @@ abstract class RDD[T: ClassManifest]( /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - /** - * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` - * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. - * This is used to truncate very long lineages. In the current implementation, Spark will save - * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. - * Hence, it is strongly recommended to use checkpoint() on RDDs when - * (i) Checkpoint() is called before the any job has been executed on this RDD. 
- * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will - * require recomputation. - */ - protected[spark] def checkpoint() { - synchronized { - if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) { - // do nothing - } else if (isCheckpointable) { - if (sc.checkpointDir == null) { - throw new Exception("Checkpoint directory has not been set in the SparkContext.") - } - shouldCheckpoint = true - } else { - throw new Exception(this + " cannot be checkpointed") - } - } - } - - def getCheckpointData(): Any = { - synchronized { - checkpointFile - } - } - - /** - * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler after a job - * using this RDD has completed (therefore the RDD has been materialized and - * potentially stored in memory). In case this RDD is not marked for checkpointing, - * doCheckpoint() is called recursively on the parent RDDs. - */ - private[spark] def doCheckpoint() { - val startCheckpoint = synchronized { - if (isCheckpointable && shouldCheckpoint && !isCheckpointInProgress) { - isCheckpointInProgress = true - true - } else { - false - } - } - - if (startCheckpoint) { - val rdd = this - rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString - rdd.saveAsObjectFile(checkpointFile) - rdd.synchronized { - rdd.checkpointRDD = context.objectFile[T](checkpointFile, rdd.splits.size) - rdd.checkpointRDDSplits = rdd.checkpointRDD.splits - rdd.changeDependencies(rdd.checkpointRDD) - rdd.shouldCheckpoint = false - rdd.isCheckpointInProgress = false - rdd.isCheckpointed = true - println("Done checkpointing RDD " + rdd.id + ", " + rdd + ", created RDD " + rdd.checkpointRDD.id + ", " + rdd.checkpointRDD) - } + def getPreferredLocations(split: Split) = { + if (isCheckpointed) { + checkpointData.preferredLocations(split) } else { - // Recursively call doCheckpoint() to perform checkpointing on parent RDD if they are marked - dependencies.foreach(_.rdd.doCheckpoint()) + preferredLocations(split) } } - /** - * Changes the dependencies of this RDD from its original parents to the new [[spark.rdd.HadoopRDD]] - * (`newRDD`) created from the checkpoint file. This method must ensure that all references - * to the original parent RDDs must be removed to enable the parent RDDs to be garbage - * collected. Subclasses of RDD may override this method for implementing their own changing - * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. - */ - protected def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD)) - } - /** * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. * This should ''not'' be called by users directly, but is available for implementors of custom @@ -247,7 +163,7 @@ abstract class RDD[T: ClassManifest]( final def iterator(split: Split): Iterator[T] = { if (isCheckpointed) { // ASSUMPTION: Checkpoint Hadoop RDD will have same number of splits as original - checkpointRDD.iterator(checkpointRDDSplits(split.index)) + checkpointData.iterator(split.index) } else if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) } else { @@ -589,6 +505,55 @@ abstract class RDD[T: ClassManifest]( sc.runJob(this, (iter: Iterator[T]) => iter.toArray) } + /** + * Mark this RDD for checkpointing. 
The RDD will be saved to a file inside `checkpointDir` + * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. + * This is used to truncate very long lineages. In the current implementation, Spark will save + * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. + * Hence, it is strongly recommended to use checkpoint() on RDDs when + * (i) checkpoint() is called before the any job has been executed on this RDD. + * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will + * require recomputation. + */ + def checkpoint() { + checkpointData.markForCheckpoint() + } + + /** + * Return whether this RDD has been checkpointed or not + */ + def isCheckpointed(): Boolean = { + checkpointData.isCheckpointed() + } + + /** + * Gets the name of the file to which this RDD was checkpointed + */ + def getCheckpointFile(): Option[String] = { + checkpointData.getCheckpointFile() + } + + /** + * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler + * after a job using this RDD has completed (therefore the RDD has been materialized and + * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. + */ + protected[spark] def doCheckpoint() { + checkpointData.doCheckpoint() + dependencies.foreach(_.rdd.doCheckpoint()) + } + + /** + * Changes the dependencies of this RDD from its original parents to the new [[spark.rdd.HadoopRDD]] + * (`newRDD`) created from the checkpoint file. This method must ensure that all references + * to the original parent RDDs must be removed to enable the parent RDDs to be garbage + * collected. Subclasses of RDD may override this method for implementing their own changing + * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. 
+ */ + protected[spark] def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD)) + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { synchronized { diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala new file mode 100644 index 0000000000..eb4482acee --- /dev/null +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -0,0 +1,68 @@ +package spark + +import org.apache.hadoop.fs.Path + + + +private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) +extends Serializable { + + class CheckpointState extends Serializable { + var state = 0 + + def mark() { if (state == 0) state = 1 } + def start() { assert(state == 1); state = 2 } + def finish() { assert(state == 2); state = 3 } + + def isMarked() = { state == 1 } + def isInProgress = { state == 2 } + def isCheckpointed = { state == 3 } + } + + val cpState = new CheckpointState() + var cpFile: Option[String] = None + var cpRDD: Option[RDD[T]] = None + var cpRDDSplits: Seq[Split] = Nil + + def markForCheckpoint() = { + rdd.synchronized { cpState.mark() } + } + + def isCheckpointed() = { + rdd.synchronized { cpState.isCheckpointed } + } + + def getCheckpointFile() = { + rdd.synchronized { cpFile } + } + + def doCheckpoint() { + rdd.synchronized { + if (cpState.isMarked && !cpState.isInProgress) { + cpState.start() + } else { + return + } + } + + val file = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString + rdd.saveAsObjectFile(file) + val newRDD = rdd.context.objectFile[T](file, rdd.splits.size) + + rdd.synchronized { + rdd.changeDependencies(newRDD) + cpFile = Some(file) + cpRDD = Some(newRDD) + cpRDDSplits = newRDD.splits + cpState.finish() + } + } + + def preferredLocations(split: Split) = { + cpRDD.get.preferredLocations(split) + } + + def iterator(splitIndex: Int): Iterator[T] = { + cpRDD.get.iterator(cpRDDSplits(splitIndex)) + } +} diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index f4c3f99011..590f9eb738 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -41,12 +41,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St } } - override def preferredLocations(split: Split) = { - if (isCheckpointed) { - checkpointRDD.preferredLocations(split) - } else { - locations_(split.asInstanceOf[BlockRDDSplit].blockId) - } - } + override def preferredLocations(split: Split) = + locations_(split.asInstanceOf[BlockRDDSplit].blockId) } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 458ad38d55..9bfc3f8ca3 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -32,12 +32,8 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( override def splits = splits_ override def preferredLocations(split: Split) = { - if (isCheckpointed) { - checkpointRDD.preferredLocations(split) - } else { - val currSplit = split.asInstanceOf[CartesianSplit] - rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) - } + val currSplit = split.asInstanceOf[CartesianSplit] + rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) } override def compute(split: Split) = { @@ -56,7 +52,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( override def dependencies = deps_ - override protected 
def changeDependencies(newRDD: RDD[_]) { + override def changeDependencies(newRDD: RDD[_]) { deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) splits_ = newRDD.splits rdd1 = null diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 94ef1b56e8..adfecea966 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -112,7 +112,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) map.iterator } - override protected def changeDependencies(newRDD: RDD[_]) { + override def changeDependencies(newRDD: RDD[_]) { deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) splits_ = newRDD.splits rdds = null diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 5b5f72ddeb..90c3b8bfd8 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -48,7 +48,7 @@ class CoalescedRDD[T: ClassManifest]( override def dependencies = deps_ - override protected def changeDependencies(newRDD: RDD[_]) { + override def changeDependencies(newRDD: RDD[_]) { deps_ = List(new OneToOneDependency(newRDD)) splits_ = newRDD.splits prev = null diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index 19ed56d9c0..a12531ea89 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -115,6 +115,4 @@ class HadoopRDD[K, V]( val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } - - override val isCheckpointable = false } diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index 2875abb2db..c12df5839e 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -93,6 +93,4 @@ class NewHadoopRDD[K, V]( val theSplit = split.asInstanceOf[NewHadoopSplit] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") } - - override val isCheckpointable = false } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 643a174160..30eb8483b6 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -49,15 +49,11 @@ class UnionRDD[T: ClassManifest]( override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() - override def preferredLocations(s: Split): Seq[String] = { - if (isCheckpointed) { - checkpointRDD.preferredLocations(s) - } else { - s.asInstanceOf[UnionSplit[T]].preferredLocations() - } - } + override def preferredLocations(s: Split): Seq[String] = + s.asInstanceOf[UnionSplit[T]].preferredLocations() + - override protected def changeDependencies(newRDD: RDD[_]) { + override def changeDependencies(newRDD: RDD[_]) { deps_ = List(new OneToOneDependency(newRDD)) splits_ = newRDD.splits rdds = null diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 4b2570fa2b..33d35b35d1 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -575,7 +575,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with return cached } // If 
the RDD has some placement preferences (as is the case for input RDDs), get those - val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList + val rddPrefs = rdd.getPreferredLocations(rdd.splits(partition)).toList if (rddPrefs != Nil) { return rddPrefs } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 8622ce92aa..2cafef444c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -41,7 +41,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(parCollection.dependencies === Nil) val result = parCollection.collect() sleep(parCollection) // slightly extra time as loading classes for the first can take some time - assert(sc.objectFile[Int](parCollection.checkpointFile).collect() === result) + assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.collect() === result) } @@ -54,7 +54,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { blockRDD.checkpoint() val result = blockRDD.collect() sleep(blockRDD) - assert(sc.objectFile[String](blockRDD.checkpointFile).collect() === result) + assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.collect() === result) } @@ -122,35 +122,6 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { "CoGroupedSplits still holds on to the splits of its parent RDDs") } - /** - * This test forces two ResultTasks of the same job to be launched before and after - * the checkpointing of job's RDD is completed. - */ - test("Threading - ResultTasks") { - val op1 = (parCollection: RDD[Int]) => { - parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) }) - } - val op2 = (firstRDD: RDD[(Int, Int)]) => { - firstRDD.map(x => { println("2nd map running on " + x); Thread.sleep(500); x }) - } - testThreading(op1, op2) - } - - /** - * This test forces two ShuffleMapTasks of the same job to be launched before and after - * the checkpointing of job's RDD is completed. 
- */ - test("Threading - ShuffleMapTasks") { - val op1 = (parCollection: RDD[Int]) => { - parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) }) - } - val op2 = (firstRDD: RDD[(Int, Int)]) => { - firstRDD.groupByKey(2).map(x => { println("2nd map running on " + x); Thread.sleep(500); x }) - } - testThreading(op1, op2) - } - - def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { val parCollection = sc.makeRDD(1 to 4, 4) val operatedRDD = op(parCollection) @@ -159,49 +130,11 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val result = operatedRDD.collect() sleep(operatedRDD) //println(parentRDD + ", " + operatedRDD.dependencies.head.rdd ) - assert(sc.objectFile[U](operatedRDD.checkpointFile).collect() === result) + assert(sc.objectFile[U](operatedRDD.getCheckpointFile.get).collect() === result) assert(operatedRDD.dependencies.head.rdd != parentRDD) assert(operatedRDD.collect() === result) } - def testThreading[U: ClassManifest, V: ClassManifest](op1: (RDD[Int]) => RDD[U], op2: (RDD[U]) => RDD[V]) { - - val parCollection = sc.makeRDD(1 to 2, 2) - - // This is the RDD that is to be checkpointed - val firstRDD = op1(parCollection) - val parentRDD = firstRDD.dependencies.head.rdd - firstRDD.checkpoint() - - // This the RDD that uses firstRDD. This is designed to launch a - // ShuffleMapTask that uses firstRDD. - val secondRDD = op2(firstRDD) - - // Starting first job, to initiate the checkpointing - logInfo("\nLaunching 1st job to initiate checkpointing\n") - firstRDD.collect() - - // Checkpointing has started but not completed yet - Thread.sleep(100) - assert(firstRDD.dependencies.head.rdd === parentRDD) - - // Starting second job; first task of this job will be - // launched _before_ firstRDD is marked as checkpointed - // and the second task will be launched _after_ firstRDD - // is marked as checkpointed - logInfo("\nLaunching 2nd job that is designed to launch tasks " + - "before and after checkpointing is complete\n") - val result = secondRDD.collect() - - // Check whether firstRDD has been successfully checkpointed - assert(firstRDD.dependencies.head.rdd != parentRDD) - - logInfo("\nRecomputing 2nd job to verify the results of the previous computation\n") - // Check whether the result in the previous job was correct or not - val correctResult = secondRDD.collect() - assert(result === correctResult) - } - def sleep(rdd: RDD[_]) { val startTime = System.currentTimeMillis() val maxWaitTime = 5000 diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index d2e9de110e..d290c5927e 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -325,8 +325,9 @@ extends Serializable with Logging { logInfo("Updating checkpoint data for time " + currentTime) // Get the checkpointed RDDs from the generated RDDs - val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointData() != null) - .map(x => (x._1, x._2.getCheckpointData())) + + val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointFile.isDefined) + .map(x => (x._1, x._2.getCheckpointFile.get)) // Make a copy of the existing checkpoint data val oldCheckpointData = checkpointData.clone() @@ -373,7 +374,7 @@ extends Serializable with Logging { logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'") val rdd = ssc.sc.objectFile[T](data.toString) // Set the 
checkpoint file name to identify this RDD as a checkpointed RDD by updateCheckpointData() - rdd.checkpointFile = data.toString + rdd.checkpointData.cpFile = Some(data.toString) generatedRDDs += ((time, rdd)) } } -- cgit v1.2.3 From 1f3a75ae9e518c003d84fa38a54583ecd841ffdc Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 7 Dec 2012 13:45:52 -0800 Subject: Modified checkpoint testsuite to more comprehensively test checkpointing of various RDDs. Fixed checkpoint bug (splits referring to parent RDDs or parent splits) in UnionRDD and CoalescedRDD. Fixed bug in testing ShuffledRDD. Removed unnecessary and useless map-side combining step for narrow dependencies in CoGroupedRDD. Removed unncessary WeakReference stuff from many other RDDs. --- core/src/main/scala/spark/rdd/CartesianRDD.scala | 1 - core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 9 +- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 25 +- core/src/main/scala/spark/rdd/FilteredRDD.scala | 6 +- core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 6 +- core/src/main/scala/spark/rdd/GlommedRDD.scala | 6 +- .../main/scala/spark/rdd/MapPartitionsRDD.scala | 6 +- .../spark/rdd/MapPartitionsWithSplitRDD.scala | 6 +- core/src/main/scala/spark/rdd/MappedRDD.scala | 6 +- core/src/main/scala/spark/rdd/PipedRDD.scala | 10 +- core/src/main/scala/spark/rdd/SampledRDD.scala | 12 +- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 16 +- core/src/main/scala/spark/rdd/UnionRDD.scala | 19 +- core/src/test/scala/spark/CheckpointSuite.scala | 267 +++++++++++++++++---- 14 files changed, 285 insertions(+), 110 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 9bfc3f8ca3..1d753a5168 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,7 +1,6 @@ package spark.rdd import spark._ -import java.lang.ref.WeakReference private[spark] class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable { diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index adfecea966..57d472666b 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -17,6 +17,7 @@ import java.io.{ObjectOutputStream, IOException} private[spark] sealed trait CoGroupSplitDep extends Serializable private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], splitIndex: Int, var split: Split = null) extends CoGroupSplitDep { + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { rdd.synchronized { @@ -50,12 +51,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) var deps_ = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { - val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) - if (mapSideCombinedRDD.partitioner == Some(part)) { - logInfo("Adding one-to-one dependency with " + mapSideCombinedRDD) - deps += new OneToOneDependency(mapSideCombinedRDD) + if (rdd.partitioner == Some(part)) { + logInfo("Adding one-to-one dependency with " + rdd) + deps += new OneToOneDependency(rdd) } else { logInfo("Adding shuffle dependency with " + rdd) + val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) deps += new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part) } } diff --git 
a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 90c3b8bfd8..0b4499e2eb 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -1,9 +1,24 @@ package spark.rdd import spark._ -import java.lang.ref.WeakReference +import java.io.{ObjectOutputStream, IOException} -private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split +private[spark] case class CoalescedRDDSplit( + index: Int, + @transient rdd: RDD[_], + parentsIndices: Array[Int] + ) extends Split { + var parents: Seq[Split] = parentsIndices.map(rdd.splits(_)) + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + rdd.synchronized { + // Update the reference to parent split at the time of task serialization + parents = parentsIndices.map(rdd.splits(_)) + oos.defaultWriteObject() + } + } +} /** * Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of @@ -21,12 +36,12 @@ class CoalescedRDD[T: ClassManifest]( @transient var splits_ : Array[Split] = { val prevSplits = prev.splits if (prevSplits.length < maxPartitions) { - prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) } + prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) } } else { (0 until maxPartitions).map { i => val rangeStart = (i * prevSplits.length) / maxPartitions val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions - new CoalescedRDDSplit(i, prevSplits.slice(rangeStart, rangeEnd)) + new CoalescedRDDSplit(i, prev, (rangeStart until rangeEnd).toArray) }.toArray } } @@ -42,7 +57,7 @@ class CoalescedRDD[T: ClassManifest]( var deps_ : List[Dependency[_]] = List( new NarrowDependency(prev) { def getParents(id: Int): Seq[Int] = - splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index) + splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices } ) diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index 1370cf6faf..02f2e7c246 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -1,15 +1,13 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference private[spark] class FilteredRDD[T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], f: T => Boolean) - extends RDD[T](prev.get) { + extends RDD[T](prev) { override def splits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).filter(f) diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index 6b2cc67568..cdc8ecdcfe 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -1,15 +1,13 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference private[spark] class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], f: T => TraversableOnce[U]) - extends RDD[U](prev.get) { + extends RDD[U](prev) { override def splits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).flatMap(f) diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index 0f0b6ab0ff..df6f61c69d 100644 --- 
a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -1,13 +1,11 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference private[spark] -class GlommedRDD[T: ClassManifest](prev: WeakReference[RDD[T]]) - extends RDD[Array[T]](prev.get) { +class GlommedRDD[T: ClassManifest](prev: RDD[T]) + extends RDD[Array[T]](prev) { override def splits = firstParent[T].splits override def compute(split: Split) = Array(firstParent[T].iterator(split).toArray).iterator } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index b04f56cfcc..23b9fb023b 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -1,16 +1,14 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false) - extends RDD[U](prev.get) { + extends RDD[U](prev) { override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index 7a4b6ffb03..41955c1d7a 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -1,9 +1,7 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference /** * A variant of the MapPartitionsRDD that passes the split index into the @@ -12,9 +10,9 @@ import java.lang.ref.WeakReference */ private[spark] class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], f: (Int, Iterator[T]) => Iterator[U]) - extends RDD[U](prev.get) { + extends RDD[U](prev) { override def splits = firstParent[T].splits override def compute(split: Split) = f(split.index, firstParent[T].iterator(split)) diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 8fa1872e0a..6f8cb21fd3 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -1,15 +1,13 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference private[spark] class MappedRDD[U: ClassManifest, T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], f: T => U) - extends RDD[U](prev.get) { + extends RDD[U](prev) { override def splits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).map(f) diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index d9293a9d1a..d2047375ea 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -8,11 +8,9 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.io.Source -import spark.OneToOneDependency import spark.RDD import spark.SparkEnv import spark.Split -import java.lang.ref.WeakReference /** @@ -20,16 +18,16 @@ import java.lang.ref.WeakReference * (printing them 
one per line) and returns the output as a collection of strings. */ class PipedRDD[T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], command: Seq[String], envVars: Map[String, String]) - extends RDD[String](prev.get) { + extends RDD[String](prev) { - def this(prev: WeakReference[RDD[T]], command: Seq[String]) = this(prev, command, Map()) + def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map()) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(prev: WeakReference[RDD[T]], command: String) = this(prev, PipedRDD.tokenize(command)) + def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) override def splits = firstParent[T].splits diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index f273f257f8..c622e14a66 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -7,7 +7,6 @@ import cern.jet.random.engine.DRand import spark.RDD import spark.OneToOneDependency import spark.Split -import java.lang.ref.WeakReference private[spark] class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable { @@ -15,14 +14,14 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali } class SampledRDD[T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], withReplacement: Boolean, frac: Double, seed: Int) - extends RDD[T](prev.get) { + extends RDD[T](prev) { @transient - val splits_ = { + var splits_ : Array[Split] = { val rg = new Random(seed) firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } @@ -51,4 +50,9 @@ class SampledRDD[T: ClassManifest]( firstParent[T].iterator(split.prev).filter(x => (rand.nextDouble <= frac)) } } + + override def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 31774585f4..a9dd3f35ed 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -1,11 +1,7 @@ package spark.rdd -import spark.Partitioner -import spark.RDD -import spark.ShuffleDependency -import spark.SparkEnv -import spark.Split -import java.lang.ref.WeakReference +import spark._ +import scala.Some private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { override val index = idx @@ -14,15 +10,15 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { /** * The resulting RDD from a shuffle (e.g. repartitioning of data). - * @param parent the parent RDD. + * @param prev the parent RDD. * @param part the partitioner used to partition the RDD * @tparam K the key class. * @tparam V the value class. 
*/ class ShuffledRDD[K, V]( - @transient prev: WeakReference[RDD[(K, V)]], + prev: RDD[(K, V)], part: Partitioner) - extends RDD[(K, V)](prev.get.context, List(new ShuffleDependency(prev.get, part))) { + extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) { override val partitioner = Some(part) @@ -37,7 +33,7 @@ class ShuffledRDD[K, V]( } override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = Nil + dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) splits_ = newRDD.splits } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 30eb8483b6..a5948dd1f1 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -3,18 +3,28 @@ package spark.rdd import scala.collection.mutable.ArrayBuffer import spark._ -import java.lang.ref.WeakReference +import java.io.{ObjectOutputStream, IOException} private[spark] class UnionSplit[T: ClassManifest]( - idx: Int, + idx: Int, rdd: RDD[T], - split: Split) + splitIndex: Int, + var split: Split = null) extends Split with Serializable { def iterator() = rdd.iterator(split) def preferredLocations() = rdd.preferredLocations(split) override val index: Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + rdd.synchronized { + // Update the reference to parent split at the time of task serialization + split = rdd.splits(splitIndex) + oos.defaultWriteObject() + } + } } class UnionRDD[T: ClassManifest]( @@ -27,7 +37,7 @@ class UnionRDD[T: ClassManifest]( val array = new Array[Split](rdds.map(_.splits.size).sum) var pos = 0 for (rdd <- rdds; split <- rdd.splits) { - array(pos) = new UnionSplit(pos, rdd, split) + array(pos) = new UnionSplit(pos, rdd, split.index) pos += 1 } array @@ -52,7 +62,6 @@ class UnionRDD[T: ClassManifest]( override def preferredLocations(s: Split): Seq[String] = s.asInstanceOf[UnionSplit[T]].preferredLocations() - override def changeDependencies(newRDD: RDD[_]) { deps_ = List(new OneToOneDependency(newRDD)) splits_ = newRDD.splits diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 2cafef444c..51bd59e2b1 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -2,17 +2,16 @@ package spark import org.scalatest.{BeforeAndAfter, FunSuite} import java.io.File -import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD} +import spark.rdd._ import spark.SparkContext._ import storage.StorageLevel -import java.util.concurrent.Semaphore -import collection.mutable.ArrayBuffer class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { initLogging() var sc: SparkContext = _ var checkpointDir: File = _ + val partitioner = new HashPartitioner(2) before { checkpointDir = File.createTempFile("temp", "") @@ -40,7 +39,6 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { parCollection.checkpoint() assert(parCollection.dependencies === Nil) val result = parCollection.collect() - sleep(parCollection) // slightly extra time as loading classes for the first can take some time assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.collect() === result) @@ -53,7 +51,6 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val blockRDD = new BlockRDD[String](sc, Array(blockId)) 
blockRDD.checkpoint() val result = blockRDD.collect() - sleep(blockRDD) assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.collect() === result) @@ -68,79 +65,247 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { testCheckpointing(_.mapPartitions(_.map(_.toString))) testCheckpointing(r => new MapPartitionsWithSplitRDD(r, (i: Int, iter: Iterator[Int]) => iter.map(_.toString) )) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), 1000) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x), 1000) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) testCheckpointing(_.pipe(Seq("cat"))) } test("ShuffledRDD") { - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _)) + // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD + testCheckpointing(rdd => { + new ShuffledRDD(rdd.map(x => (x % 2, 1)), partitioner) + }) } test("UnionRDD") { - testCheckpointing(_.union(sc.makeRDD(5 to 6, 4))) + def otherRDD = sc.makeRDD(1 to 10, 4) + testCheckpointing(_.union(otherRDD), false, true) + testParentCheckpointing(_.union(otherRDD), false, true) } test("CartesianRDD") { - testCheckpointing(_.cartesian(sc.makeRDD(5 to 6, 4)), 1000) + def otherRDD = sc.makeRDD(1 to 10, 4) + testCheckpointing(_.cartesian(otherRDD)) + testParentCheckpointing(_.cartesian(otherRDD), true, false) } test("CoalescedRDD") { testCheckpointing(new CoalescedRDD(_, 2)) + + // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed + // Current implementation of CoalescedRDDSplit has transient reference to parent RDD, + // so does not serialize the RDD (not need to check its size). + testParentCheckpointing(new CoalescedRDD(_, 2), true, false) + + // Test that the CoalescedRDDSplit updates parent splits (CoalescedRDDSplit.parents) after + // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits. + // Note that this test is very specific to the current implementation of CoalescedRDDSplits + val ones = sc.makeRDD(1 to 100, 10).map(x => x) + ones.checkpoint // checkpoint that MappedRDD + val coalesced = new CoalescedRDD(ones, 2) + val splitBeforeCheckpoint = + serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit]) + coalesced.count() // do the checkpointing + val splitAfterCheckpoint = + serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit]) + assert( + splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head, + "CoalescedRDDSplit.parents not updated after parent RDD checkpointed" + ) } test("CoGroupedRDD") { - val rdd2 = sc.makeRDD(5 to 6, 4).map(x => (x % 2, 1)) - testCheckpointing(rdd1 => rdd1.map(x => (x % 2, 1)).cogroup(rdd2)) - testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2)) + // Test serialized size + // RDD with long lineage of one-to-one dependencies through cogroup transformations + val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD() + testCheckpointing(rdd1 => { + CheckpointSuite.cogroup(longLineageRDD1, rdd1.map(x => (x % 2, 1)), partitioner) + }, false, true) - // Special test to make sure that the CoGroupSplit of CoGroupedRDD do not - // hold on to the splits of its parent RDDs, as the splits of parent RDDs - // may change while checkpointing. 
Rather the splits of parent RDDs must - // be fetched at the time of serialization to ensure the latest splits to - // be sent along with the task. + val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD() + testParentCheckpointing(rdd1 => { + CheckpointSuite.cogroup(longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner) + }, false, true) + } - val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _) + /** + * Test checkpointing of the final RDD generated by the given operation. By default, + * this method tests whether the size of serialized RDD has reduced after checkpointing or not. + * It can also test whether the size of serialized RDD splits has reduced after checkpointing or + * not, but this is not done by default as usually the splits do not refer to any RDD and + * therefore never store the lineage. + */ + def testCheckpointing[U: ClassManifest]( + op: (RDD[Int]) => RDD[U], + testRDDSize: Boolean = true, + testRDDSplitSize: Boolean = false + ) { + // Generate the final RDD using given RDD operation + val baseRDD = generateLongLineageRDD + val operatedRDD = op(baseRDD) + val parentRDD = operatedRDD.dependencies.headOption.orNull + val rddType = operatedRDD.getClass.getSimpleName - val ones = sc.parallelize(1 to 100, 1).map(x => (x,1)) - val reduced = ones.reduceByKey(_ + _) - val seqOfCogrouped = new ArrayBuffer[RDD[(Int, Int)]]() - seqOfCogrouped += reduced.cogroup(ones).mapValues[Int](add) - for(i <- 1 to 10) { - seqOfCogrouped += seqOfCogrouped.last.cogroup(ones).mapValues(add) - } - val finalCogrouped = seqOfCogrouped.last - val intermediateCogrouped = seqOfCogrouped(5) - - val bytesBeforeCheckpoint = Utils.serialize(finalCogrouped.splits) - intermediateCogrouped.checkpoint() - finalCogrouped.count() - sleep(intermediateCogrouped) - val bytesAfterCheckpoint = Utils.serialize(finalCogrouped.splits) - println("Before = " + bytesBeforeCheckpoint.size + ", after = " + bytesAfterCheckpoint.size) - assert(bytesAfterCheckpoint.size < bytesBeforeCheckpoint.size, - "CoGroupedSplits still holds on to the splits of its parent RDDs") - } - - def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { - val parCollection = sc.makeRDD(1 to 4, 4) - val operatedRDD = op(parCollection) + // Find serialized sizes before and after the checkpoint + val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) operatedRDD.checkpoint() - val parentRDD = operatedRDD.dependencies.head.rdd val result = operatedRDD.collect() - sleep(operatedRDD) - //println(parentRDD + ", " + operatedRDD.dependencies.head.rdd ) + val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + + // Test whether the checkpoint file has been created assert(sc.objectFile[U](operatedRDD.getCheckpointFile.get).collect() === result) + + // Test whether dependencies have been changed from its earlier parent RDD assert(operatedRDD.dependencies.head.rdd != parentRDD) + + // Test whether the splits have been changed to the new Hadoop splits + assert(operatedRDD.splits.toList === operatedRDD.checkpointData.cpRDDSplits.toList) + + // Test whether the data in the checkpointed RDD is same as original assert(operatedRDD.collect() === result) + + // Test whether serialized size of the RDD has reduced. If the RDD + // does not have any dependency to another RDD (e.g., ParallelCollection, + // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing. 
+ if (testRDDSize) { + println("Size of " + rddType + + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after checkpointing " + + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + // Test whether serialized size of the splits has reduced. If the splits + // do not have any non-transient reference to another RDD or another RDD's splits, it + // does not refer to a lineage and therefore may not reduce in size after checkpointing. + // However, if the original splits before checkpointing do refer to a parent RDD, the splits + // must be forgotten after checkpointing (to remove all reference to parent RDDs) and + // replaced with the HadoopSplits of the checkpointed RDD. + if (testRDDSplitSize) { + println("Size of " + rddType + " splits " + + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]") + assert( + splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, + "Size of " + rddType + " splits did not reduce after checkpointing " + + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" + ) + } } - def sleep(rdd: RDD[_]) { - val startTime = System.currentTimeMillis() - val maxWaitTime = 5000 - while(rdd.isCheckpointed == false && System.currentTimeMillis() < startTime + maxWaitTime) { - Thread.sleep(50) + /** + * Test whether checkpointing of the parent of the generated RDD also + * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent + * RDDs splits. So even if the parent RDD is checkpointed and its splits changed, + * this RDD will remember the splits and therefore potentially the whole lineage. + */ + def testParentCheckpointing[U: ClassManifest]( + op: (RDD[Int]) => RDD[U], + testRDDSize: Boolean, + testRDDSplitSize: Boolean + ) { + // Generate the final RDD using given RDD operation + val baseRDD = generateLongLineageRDD + val operatedRDD = op(baseRDD) + val parentRDD = operatedRDD.dependencies.head.rdd + val rddType = operatedRDD.getClass.getSimpleName + val parentRDDType = parentRDD.getClass.getSimpleName + + // Find serialized sizes before and after the checkpoint + val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one + val result = operatedRDD.collect() + val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + + // Test whether the data in the checkpointed RDD is same as original + assert(operatedRDD.collect() === result) + + // Test whether serialized size of the RDD has reduced because of its parent being + // checkpointed. If this RDD or its parent RDD do not have any dependency + // to another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), it may + // not reduce in size after checkpointing. + if (testRDDSize) { + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after parent checkpointing parent " + parentRDDType + + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + // Test whether serialized size of the splits has reduced because of its parent being + // checkpointed. If the splits do not have any non-transient reference to another RDD + // or another RDD's splits, it does not refer to a lineage and therefore may not reduce + // in size after checkpointing. 
However, if the splits do refer to the *splits* of a parent + // RDD, then these splits must update reference to the parent RDD splits as the parent RDD's + // splits must have changed after checkpointing. + if (testRDDSplitSize) { + assert( + splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, + "Size of " + rddType + " splits did not reduce after checkpointing parent " + parentRDDType + + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" + ) } - assert(rdd.isCheckpointed === true, "Waiting for checkpoint to complete took more than " + maxWaitTime + " ms") + + } + + /** + * Generate an RDD with a long lineage of one-to-one dependencies. + */ + def generateLongLineageRDD(): RDD[Int] = { + var rdd = sc.makeRDD(1 to 100, 4) + for (i <- 1 to 20) { + rdd = rdd.map(x => x) + } + rdd + } + + /** + * Generate an RDD with a long lineage specifically for CoGroupedRDD. + * A CoGroupedRDD can have a long lineage only one of its parents have a long lineage + * and narrow dependency with this RDD. This method generate such an RDD by a sequence + * of cogroups and mapValues which creates a long lineage of narrow dependencies. + */ + def generateLongLineageRDDForCoGroupedRDD() = { + val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _) + + def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) + + var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones) + for(i <- 1 to 10) { + cogrouped = cogrouped.mapValues(add).cogroup(ones) + } + cogrouped.mapValues(add) + } + + /** + * Get serialized sizes of the RDD and its splits + */ + def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { + (Utils.serialize(rdd).size, Utils.serialize(rdd.splits).size) + } + + /** + * Serialize and deserialize an object. 
This is useful to verify the objects + * contents after deserialization (e.g., the contents of an RDD split after + * it is sent to a slave along with a task) + */ + def serializeDeserialize[T](obj: T): T = { + val bytes = Utils.serialize(obj) + Utils.deserialize[T](bytes) } } + + +object CheckpointSuite { + // This is a custom cogroup function that does not use mapValues like + // the PairRDDFunctions.cogroup() + def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { + println("First = " + first + ", second = " + second) + new CoGroupedRDD[K]( + Seq(first.asInstanceOf[RDD[(_, _)]], second.asInstanceOf[RDD[(_, _)]]), + part + ).asInstanceOf[RDD[(K, Seq[Seq[V]])]] + } + +} \ No newline at end of file -- cgit v1.2.3 From c36ca10241991d46f2f1513b2c0c5e369d8b34f9 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 3 Nov 2012 15:19:41 -0700 Subject: Adding locality aware parallelize --- core/src/main/scala/spark/ParallelCollection.scala | 11 +++++++++-- core/src/main/scala/spark/SparkContext.scala | 10 +++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9725017b61..4bd9e1bd54 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -2,6 +2,7 @@ package spark import scala.collection.immutable.NumericRange import scala.collection.mutable.ArrayBuffer +import scala.collection.Map private[spark] class ParallelCollectionSplit[T: ClassManifest]( val rddId: Long, @@ -24,7 +25,8 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( private[spark] class ParallelCollection[T: ClassManifest]( @transient sc : SparkContext, @transient data: Seq[T], - numSlices: Int) + numSlices: Int, + locationPrefs : Map[Int,Seq[String]]) extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split @@ -40,7 +42,12 @@ private[spark] class ParallelCollection[T: ClassManifest]( override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator - override def preferredLocations(s: Split): Seq[String] = Nil + override def preferredLocations(s: Split): Seq[String] = { + locationPrefs.get(splits_.indexOf(s)) match { + case Some(s) => s + case _ => Nil + } + } } private object ParallelCollection { diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d7b46bee38..7ae1aea993 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -194,7 +194,7 @@ class SparkContext( /** Distribute a local Scala collection to form an RDD. */ def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { - new ParallelCollection[T](this, seq, numSlices) + new ParallelCollection[T](this, seq, numSlices, Map[Int, Seq[String]]()) } /** Distribute a local Scala collection to form an RDD. */ @@ -202,6 +202,14 @@ class SparkContext( parallelize(seq, numSlices) } + /** Distribute a local Scala collection to form an RDD, with one or more + * location preferences for each object. Create a new partition for each + * collection item. 
*/ + def makeLocalityConstrainedRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { + val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap + new ParallelCollection[T](this, seq.map(_._1), seq.size, indexToPrefs) + } + /** * Read a text file from HDFS, a local file system (available on all nodes), or any * Hadoop-supported file system URI, and return it as an RDD of Strings. -- cgit v1.2.3 From 3e796bdd57297134ed40b20d7692cd9c8cd6efba Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 7 Dec 2012 19:16:35 -0800 Subject: Changes in response to TD's review. --- core/src/main/scala/spark/SparkContext.scala | 6 +++--- .../scala/spark/streaming/FlumeInputDStream.scala | 2 +- .../spark/streaming/NetworkInputDStream.scala | 4 ++-- .../spark/streaming/NetworkInputTracker.scala | 10 ++++----- .../scala/spark/streaming/RawInputDStream.scala | 2 +- .../scala/spark/streaming/SocketInputDStream.scala | 2 +- .../spark/streaming/examples/FlumeEventCount.scala | 24 +++++++++++++++++----- 7 files changed, 32 insertions(+), 18 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 7ae1aea993..3ccdbfe10e 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -203,9 +203,9 @@ class SparkContext( } /** Distribute a local Scala collection to form an RDD, with one or more - * location preferences for each object. Create a new partition for each - * collection item. */ - def makeLocalityConstrainedRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { + * location preferences (hostnames of Spark nodes) for each object. + * Create a new partition for each collection item. */ + def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollection[T](this, seq.map(_._1), seq.size, indexToPrefs) } diff --git a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala index 9c403278c3..2959ce4540 100644 --- a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala @@ -126,5 +126,5 @@ class FlumeReceiver( logInfo("Flume receiver stopped") } - override def getLocationConstraint = Some(host) + override def getLocationPreference = Some(host) } \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala index 052fc8bb74..4e4e9fc942 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -62,8 +62,8 @@ abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Seri /** This method will be called to stop receiving data. */ protected def onStop() - /** This method conveys a placement constraint (hostname) for this receiver. */ - def getLocationConstraint() : Option[String] = None + /** This method conveys a placement preference (hostname) for this receiver. */ + def getLocationPreference() : Option[String] = None /** * This method starts the receiver. 
First is accesses all the lazy members to diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index 56661c2615..b421f795ee 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -99,13 +99,13 @@ class NetworkInputTracker( def startReceivers() { val receivers = networkInputStreams.map(_.createReceiver()) - // We only honor constraints if all receivers have them - val hasLocationConstraints = receivers.map(_.getLocationConstraint().isDefined).reduce(_ && _) + // Right now, we only honor preferences if all receivers have them + val hasLocationPreferences = receivers.map(_.getLocationPreference().isDefined).reduce(_ && _) val tempRDD = - if (hasLocationConstraints) { - val receiversWithConstraints = receivers.map(r => (r, Seq(r.getLocationConstraint().toString))) - ssc.sc.makeLocalityConstrainedRDD[NetworkReceiver[_]](receiversWithConstraints) + if (hasLocationPreferences) { + val receiversWithPreferences = receivers.map(r => (r, Seq(r.getLocationPreference().toString))) + ssc.sc.makeRDD[NetworkReceiver[_]](receiversWithPreferences) } else { ssc.sc.makeRDD(receivers, receivers.size) diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala index fd51ed47a5..6acaa9aab1 100644 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -31,7 +31,7 @@ class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: S var blockPushingThread: Thread = null - override def getLocationConstraint = None + override def getLocationPreference = None def onStart() { // Open a socket to the target address and keep reading from it diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala index ebbb17a39a..a9e37c0ff0 100644 --- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala @@ -34,7 +34,7 @@ class SocketReceiver[T: ClassManifest]( lazy protected val dataHandler = new DataHandler(this, storageLevel) - override def getLocationConstraint = None + override def getLocationPreference = None protected def onStart() { logInfo("Connecting to " + host + ":" + port) diff --git a/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala b/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala index d76c92fdd5..e60ce483a3 100644 --- a/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala +++ b/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala @@ -4,19 +4,33 @@ import spark.util.IntParam import spark.storage.StorageLevel import spark.streaming._ +/** + * Produce a streaming count of events received from Flume. + * + * This should be used in conjunction with an AvroSink in Flume. It will start + * an Avro server on at the request host:port address and listen for requests. + * Your Flume AvroSink should be pointed to this address. + * + * Usage: FlumeEventCount + * + * is a Spark master URL + * is the host the Flume receiver will be started on - a receiver + * creates a server and listens for flume events. + * is the port the Flume receiver will listen on. 
+ */ object FlumeEventCount { def main(args: Array[String]) { - if (args.length != 4) { + if (args.length != 3) { System.err.println( - "Usage: FlumeEventCount ") + "Usage: FlumeEventCount ") System.exit(1) } - val Array(master, host, IntParam(port), IntParam(batchMillis)) = args + val Array(master, host, IntParam(port)) = args + val batchInterval = Milliseconds(2000) // Create the context and set the batch size - val ssc = new StreamingContext(master, "FlumeEventCount", - Milliseconds(batchMillis)) + val ssc = new StreamingContext(master, "FlumeEventCount", batchInterval) // Create a flume stream val stream = ssc.flumeStream(host,port,StorageLevel.MEMORY_ONLY) -- cgit v1.2.3 From e42721601898ff199ca1c6cfeae159ad3ef691e3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 8 Dec 2012 12:46:59 -0800 Subject: Removed unnecessary testcases. --- core/src/test/scala/spark/CheckpointSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 8622ce92aa..41d84cb01c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -121,7 +121,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(bytesAfterCheckpoint.size < bytesBeforeCheckpoint.size, "CoGroupedSplits still holds on to the splits of its parent RDDs") } - + /* /** * This test forces two ResultTasks of the same job to be launched before and after * the checkpointing of job's RDD is completed. @@ -149,7 +149,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { } testThreading(op1, op2) } - + */ def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { val parCollection = sc.makeRDD(1 to 4, 4) @@ -163,7 +163,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(operatedRDD.dependencies.head.rdd != parentRDD) assert(operatedRDD.collect() === result) } - + /* def testThreading[U: ClassManifest, V: ClassManifest](op1: (RDD[Int]) => RDD[U], op2: (RDD[U]) => RDD[V]) { val parCollection = sc.makeRDD(1 to 2, 2) @@ -201,7 +201,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val correctResult = secondRDD.collect() assert(result === correctResult) } - + */ def sleep(rdd: RDD[_]) { val startTime = System.currentTimeMillis() val maxWaitTime = 5000 -- cgit v1.2.3 From 746afc2e6513d5f32f261ec0dbf2823f78a5e960 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 10 Dec 2012 23:36:37 -0800 Subject: Bunch of bug fixes related to checkpointing in RDDs. RDDCheckpointData object is used to lock all serialization and dependency changes for checkpointing. ResultTask converted to Externalizable and serialized RDD is cached like ShuffleMapTask. 
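The locking described in this commit message can be shown in isolation. The sketch below is not the patch itself, and none of these names (CheckpointLock, FakeRdd, FakeTask, LockDemo) exist in Spark; it only illustrates the pattern of having task serialization and checkpoint-driven split swaps synchronize on one shared object, so a task never serializes a split reference while the RDD is being rewired.

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    // Shared lock object; both sides below synchronize on it.
    object CheckpointLock

    class FakeRdd(var splits: Array[Int]) {
      // Called once the checkpoint data has been written: swap splits under the lock.
      def swapSplits(newSplits: Array[Int]) {
        CheckpointLock.synchronized { splits = newSplits }
      }
    }

    class FakeTask(rdd: FakeRdd, partition: Int) {
      // Re-read the split under the same lock while serializing, so the task can
      // never capture a half-updated view of the RDD.
      def serializedSplit(): Array[Byte] = CheckpointLock.synchronized {
        val buffer = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(buffer)
        out.writeObject("split-" + rdd.splits(partition))
        out.close()
        buffer.toByteArray
      }
    }

    object LockDemo extends App {
      val rdd = new FakeRdd(Array(0, 1, 2, 3))
      val task = new FakeTask(rdd, 2)
      val before = task.serializedSplit()
      rdd.swapSplits(Array(10, 11, 12, 13)) // e.g. after writing the checkpoint file
      val after = task.serializedSplit()
      println("serialized " + before.length + " and " + after.length + " bytes")
    }

In the actual patch the role of CheckpointLock is played by the RDDCheckpointData companion object, and the guarded serialization happens in the writeExternal methods of ResultTask and ShuffleMapTask shown below.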
--- core/src/main/scala/spark/ParallelCollection.scala | 10 +- core/src/main/scala/spark/RDD.scala | 10 +- core/src/main/scala/spark/RDDCheckpointData.scala | 76 +++++++-- core/src/main/scala/spark/SparkContext.scala | 5 +- core/src/main/scala/spark/rdd/BlockRDD.scala | 9 +- core/src/main/scala/spark/rdd/CartesianRDD.scala | 21 ++- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 18 +- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 8 +- core/src/main/scala/spark/rdd/HadoopRDD.scala | 4 + core/src/main/scala/spark/rdd/UnionRDD.scala | 15 +- .../main/scala/spark/scheduler/ResultTask.scala | 95 ++++++++++- .../scala/spark/scheduler/ShuffleMapTask.scala | 21 ++- core/src/test/scala/spark/CheckpointSuite.scala | 187 ++++++++++++++++++--- 13 files changed, 389 insertions(+), 90 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9725017b61..9d12af6912 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -28,10 +28,11 @@ private[spark] class ParallelCollection[T: ClassManifest]( extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split - // instead. UPDATE: With the new changes to enable checkpointing, this an be done. + // instead. + // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal. @transient - val splits_ = { + var splits_ : Array[Split] = { val slices = ParallelCollection.slice(data, numSlices).toArray slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray } @@ -41,6 +42,11 @@ private[spark] class ParallelCollection[T: ClassManifest]( override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator override def preferredLocations(s: Split): Seq[String] = Nil + + override def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + } } private object ParallelCollection { diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e9bd131e61..efa03d5185 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -163,7 +163,7 @@ abstract class RDD[T: ClassManifest]( final def iterator(split: Split): Iterator[T] = { if (isCheckpointed) { // ASSUMPTION: Checkpoint Hadoop RDD will have same number of splits as original - checkpointData.iterator(split.index) + checkpointData.iterator(split) } else if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) } else { @@ -556,16 +556,12 @@ abstract class RDD[T: ClassManifest]( @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { - synchronized { - oos.defaultWriteObject() - } + oos.defaultWriteObject() } @throws(classOf[IOException]) private def readObject(ois: ObjectInputStream) { - synchronized { - ois.defaultReadObject() - } + ois.defaultReadObject() } } diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index eb4482acee..ff2ed4cdfc 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -1,12 +1,20 @@ package spark import org.apache.hadoop.fs.Path +import rdd.CoalescedRDD 
+import scheduler.{ResultTask, ShuffleMapTask} - +/** + * This class contains all the information of the regarding RDD checkpointing. + */ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) -extends Serializable { +extends Logging with Serializable { + /** + * This class manages the state transition of an RDD through checkpointing + * [ Not checkpointed --> marked for checkpointing --> checkpointing in progress --> checkpointed ] + */ class CheckpointState extends Serializable { var state = 0 @@ -20,24 +28,30 @@ extends Serializable { } val cpState = new CheckpointState() - var cpFile: Option[String] = None - var cpRDD: Option[RDD[T]] = None - var cpRDDSplits: Seq[Split] = Nil + @transient var cpFile: Option[String] = None + @transient var cpRDD: Option[RDD[T]] = None + @transient var cpRDDSplits: Seq[Split] = Nil + // Mark the RDD for checkpointing def markForCheckpoint() = { - rdd.synchronized { cpState.mark() } + RDDCheckpointData.synchronized { cpState.mark() } } + // Is the RDD already checkpointed def isCheckpointed() = { - rdd.synchronized { cpState.isCheckpointed } + RDDCheckpointData.synchronized { cpState.isCheckpointed } } + // Get the file to which this RDD was checkpointed to as a Option def getCheckpointFile() = { - rdd.synchronized { cpFile } + RDDCheckpointData.synchronized { cpFile } } + // Do the checkpointing of the RDD. Called after the first job using that RDD is over. def doCheckpoint() { - rdd.synchronized { + // If it is marked for checkpointing AND checkpointing is not already in progress, + // then set it to be in progress, else return + RDDCheckpointData.synchronized { if (cpState.isMarked && !cpState.isInProgress) { cpState.start() } else { @@ -45,24 +59,56 @@ extends Serializable { } } + // Save to file, and reload it as an RDD val file = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString rdd.saveAsObjectFile(file) - val newRDD = rdd.context.objectFile[T](file, rdd.splits.size) - rdd.synchronized { - rdd.changeDependencies(newRDD) + val newRDD = { + val hadoopRDD = rdd.context.objectFile[T](file, rdd.splits.size) + + val oldSplits = rdd.splits.size + val newSplits = hadoopRDD.splits.size + + logDebug("RDD splits = " + oldSplits + " --> " + newSplits) + if (newSplits < oldSplits) { + throw new Exception("# splits after checkpointing is less than before " + + "[" + oldSplits + " --> " + newSplits) + } else if (newSplits > oldSplits) { + new CoalescedRDD(hadoopRDD, rdd.splits.size) + } else { + hadoopRDD + } + } + logDebug("New RDD has " + newRDD.splits.size + " splits") + + // Change the dependencies and splits of the RDD + RDDCheckpointData.synchronized { cpFile = Some(file) cpRDD = Some(newRDD) cpRDDSplits = newRDD.splits + rdd.changeDependencies(newRDD) cpState.finish() + RDDCheckpointData.checkpointCompleted() + logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) } } + // Get preferred location of a split after checkpointing def preferredLocations(split: Split) = { - cpRDD.get.preferredLocations(split) + RDDCheckpointData.synchronized { + cpRDD.get.preferredLocations(split) + } } - def iterator(splitIndex: Int): Iterator[T] = { - cpRDD.get.iterator(cpRDDSplits(splitIndex)) + // Get iterator. This is called at the worker nodes. 
+ def iterator(split: Split): Iterator[T] = { + rdd.firstParent[T].iterator(split) + } +} + +private[spark] object RDDCheckpointData { + def checkpointCompleted() { + ShuffleMapTask.clearCache() + ResultTask.clearCache() } } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d7b46bee38..654b1c2eb7 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -40,9 +40,7 @@ import spark.partial.PartialResult import spark.rdd.HadoopRDD import spark.rdd.NewHadoopRDD import spark.rdd.UnionRDD -import spark.scheduler.ShuffleMapTask -import spark.scheduler.DAGScheduler -import spark.scheduler.TaskScheduler +import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} @@ -486,6 +484,7 @@ class SparkContext( clearJars() SparkEnv.set(null) ShuffleMapTask.clearCache() + ResultTask.clearCache() logInfo("Successfully stopped SparkContext") } diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index 590f9eb738..0c8cdd10dd 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -2,7 +2,7 @@ package spark.rdd import scala.collection.mutable.HashMap -import spark.Dependency +import spark.OneToOneDependency import spark.RDD import spark.SparkContext import spark.SparkEnv @@ -17,7 +17,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St extends RDD[T](sc, Nil) { @transient - val splits_ = (0 until blockIds.size).map(i => { + var splits_ : Array[Split] = (0 until blockIds.size).map(i => { new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] }).toArray @@ -43,5 +43,10 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St override def preferredLocations(split: Split) = locations_(split.asInstanceOf[BlockRDDSplit].blockId) + + override def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + } } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 1d753a5168..9975e79b08 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,10 +1,27 @@ package spark.rdd import spark._ +import java.io.{ObjectOutputStream, IOException} private[spark] -class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable { +class CartesianSplit( + idx: Int, + @transient rdd1: RDD[_], + @transient rdd2: RDD[_], + s1Index: Int, + s2Index: Int + ) extends Split { + var s1 = rdd1.splits(s1Index) + var s2 = rdd2.splits(s2Index) override val index: Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + s1 = rdd1.splits(s1Index) + s2 = rdd2.splits(s2Index) + oos.defaultWriteObject() + } } private[spark] @@ -23,7 +40,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) for (s1 <- rdd1.splits; s2 <- rdd2.splits) { val idx = s1.index * numSplitsInRdd2 + s2.index - array(idx) = new 
CartesianSplit(idx, s1, s2) + array(idx) = new CartesianSplit(idx, rdd1, rdd2, s1.index, s2.index) } array } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 57d472666b..e4e70b13ba 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -20,11 +20,9 @@ private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], splitIndex: Int, va @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { - rdd.synchronized { - // Update the reference to parent split at the time of task serialization - split = rdd.splits(splitIndex) - oos.defaultWriteObject() - } + // Update the reference to parent split at the time of task serialization + split = rdd.splits(splitIndex) + oos.defaultWriteObject() } } private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep @@ -42,7 +40,8 @@ private[spark] class CoGroupAggregator { (b1, b2) => b1 ++ b2 }) with Serializable -class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) +class +CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging { val aggr = new CoGroupAggregator @@ -63,7 +62,9 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) deps.toList } - override def dependencies = deps_ + // Pre-checkpoint dependencies deps_ should be transient (deps_) + // but post-checkpoint dependencies must not be transient (dependencies_) + override def dependencies = if (isCheckpointed) dependencies_ else deps_ @transient var splits_ : Array[Split] = { @@ -114,7 +115,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) } override def changeDependencies(newRDD: RDD[_]) { - deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + deps_ = null + dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) splits_ = newRDD.splits rdds = null } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 0b4499e2eb..088958942e 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -12,11 +12,9 @@ private[spark] case class CoalescedRDDSplit( @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { - rdd.synchronized { - // Update the reference to parent split at the time of task serialization - parents = parentsIndices.map(rdd.splits(_)) - oos.defaultWriteObject() - } + // Update the reference to parent split at the time of task serialization + parents = parentsIndices.map(rdd.splits(_)) + oos.defaultWriteObject() } } diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index a12531ea89..af54f23ebc 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -115,4 +115,8 @@ class HadoopRDD[K, V]( val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } + + override def checkpoint() { + // Do nothing. Hadoop RDD cannot be checkpointed. 
+ } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index a5948dd1f1..808729f18d 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -19,11 +19,9 @@ private[spark] class UnionSplit[T: ClassManifest]( @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { - rdd.synchronized { - // Update the reference to parent split at the time of task serialization - split = rdd.splits(splitIndex) - oos.defaultWriteObject() - } + // Update the reference to parent split at the time of task serialization + split = rdd.splits(splitIndex) + oos.defaultWriteObject() } } @@ -55,7 +53,9 @@ class UnionRDD[T: ClassManifest]( deps.toList } - override def dependencies = deps_ + // Pre-checkpoint dependencies deps_ should be transient (deps_) + // but post-checkpoint dependencies must not be transient (dependencies_) + override def dependencies = if (isCheckpointed) dependencies_ else deps_ override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() @@ -63,7 +63,8 @@ class UnionRDD[T: ClassManifest]( s.asInstanceOf[UnionSplit[T]].preferredLocations() override def changeDependencies(newRDD: RDD[_]) { - deps_ = List(new OneToOneDependency(newRDD)) + deps_ = null + dependencies_ = List(new OneToOneDependency(newRDD)) splits_ = newRDD.splits rdds = null } diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index 2ebd4075a2..bcb9e4956b 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -1,17 +1,73 @@ package spark.scheduler import spark._ +import java.io._ +import util.{MetadataCleaner, TimeStampedHashMap} +import java.util.zip.{GZIPInputStream, GZIPOutputStream} + +private[spark] object ResultTask { + + // A simple map between the stage id to the serialized byte array of a task. + // Served as a cache for task serialization because serialization can be + // expensive on the master node if it needs to launch thousands of tasks. 
+ val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] + + val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.cleanup) + + def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { + synchronized { + val old = serializedInfoCache.get(stageId).orNull + if (old != null) { + return old + } else { + val out = new ByteArrayOutputStream + val ser = SparkEnv.get.closureSerializer.newInstance + val objOut = ser.serializeStream(new GZIPOutputStream(out)) + objOut.writeObject(rdd) + objOut.writeObject(func) + objOut.close() + val bytes = out.toByteArray + serializedInfoCache.put(stageId, bytes) + return bytes + } + } + } + + def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = { + synchronized { + val loader = Thread.currentThread.getContextClassLoader + val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) + val ser = SparkEnv.get.closureSerializer.newInstance + val objIn = ser.deserializeStream(in) + val rdd = objIn.readObject().asInstanceOf[RDD[_]] + val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] + return (rdd, func) + } + } + + def clearCache() { + synchronized { + serializedInfoCache.clear() + } + } +} + private[spark] class ResultTask[T, U]( stageId: Int, - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - val partition: Int, + var rdd: RDD[T], + var func: (TaskContext, Iterator[T]) => U, + var partition: Int, @transient locs: Seq[String], val outputId: Int) - extends Task[U](stageId) { - - val split = rdd.splits(partition) + extends Task[U](stageId) with Externalizable { + + def this() = this(0, null, null, 0, null, 0) + var split = if (rdd == null) { + null + } else { + rdd.splits(partition) + } override def run(attemptId: Long): U = { val context = new TaskContext(stageId, partition, attemptId) @@ -21,4 +77,31 @@ private[spark] class ResultTask[T, U]( override def preferredLocations: Seq[String] = locs override def toString = "ResultTask(" + stageId + ", " + partition + ")" + + override def writeExternal(out: ObjectOutput) { + RDDCheckpointData.synchronized { + split = rdd.splits(partition) + out.writeInt(stageId) + val bytes = ResultTask.serializeInfo( + stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) + out.writeInt(bytes.length) + out.write(bytes) + out.writeInt(partition) + out.writeInt(outputId) + out.writeObject(split) + } + } + + override def readExternal(in: ObjectInput) { + val stageId = in.readInt() + val numBytes = in.readInt() + val bytes = new Array[Byte](numBytes) + in.readFully(bytes) + val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) + rdd = rdd_.asInstanceOf[RDD[T]] + func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] + partition = in.readInt() + val outputId = in.readInt() + split = in.readObject().asInstanceOf[Split] + } } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 683f5ebec3..5d28c40778 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -84,19 +84,22 @@ private[spark] class ShuffleMapTask( def this() = this(0, null, null, 0, null) var split = if (rdd == null) { - null - } else { + null + } else { rdd.splits(partition) } override def writeExternal(out: ObjectOutput) { - out.writeInt(stageId) - val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) - out.writeInt(bytes.length) - 
out.write(bytes) - out.writeInt(partition) - out.writeLong(generation) - out.writeObject(split) + RDDCheckpointData.synchronized { + split = rdd.splits(partition) + out.writeInt(stageId) + val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) + out.writeInt(bytes.length) + out.write(bytes) + out.writeInt(partition) + out.writeLong(generation) + out.writeObject(split) + } } override def readExternal(in: ObjectInput) { diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 51bd59e2b1..7b323e089c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -34,13 +34,30 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { } } + test("RDDs with one-to-one dependencies") { + testCheckpointing(_.map(x => x.toString)) + testCheckpointing(_.flatMap(x => 1 to x)) + testCheckpointing(_.filter(_ % 2 == 0)) + testCheckpointing(_.sample(false, 0.5, 0)) + testCheckpointing(_.glom()) + testCheckpointing(_.mapPartitions(_.map(_.toString))) + testCheckpointing(r => new MapPartitionsWithSplitRDD(r, + (i: Int, iter: Iterator[Int]) => iter.map(_.toString) )) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) + testCheckpointing(_.pipe(Seq("cat"))) + } + test("ParallelCollection") { - val parCollection = sc.makeRDD(1 to 4) + val parCollection = sc.makeRDD(1 to 4, 2) + val numSplits = parCollection.splits.size parCollection.checkpoint() assert(parCollection.dependencies === Nil) val result = parCollection.collect() assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) + assert(parCollection.splits.length === numSplits) + assert(parCollection.splits.toList === parCollection.checkpointData.cpRDDSplits.toList) assert(parCollection.collect() === result) } @@ -49,44 +66,58 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val blockManager = SparkEnv.get.blockManager blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) val blockRDD = new BlockRDD[String](sc, Array(blockId)) + val numSplits = blockRDD.splits.size blockRDD.checkpoint() val result = blockRDD.collect() assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) + assert(blockRDD.splits.length === numSplits) + assert(blockRDD.splits.toList === blockRDD.checkpointData.cpRDDSplits.toList) assert(blockRDD.collect() === result) } - test("RDDs with one-to-one dependencies") { - testCheckpointing(_.map(x => x.toString)) - testCheckpointing(_.flatMap(x => 1 to x)) - testCheckpointing(_.filter(_ % 2 == 0)) - testCheckpointing(_.sample(false, 0.5, 0)) - testCheckpointing(_.glom()) - testCheckpointing(_.mapPartitions(_.map(_.toString))) - testCheckpointing(r => new MapPartitionsWithSplitRDD(r, - (i: Int, iter: Iterator[Int]) => iter.map(_.toString) )) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) - testCheckpointing(_.pipe(Seq("cat"))) - } - test("ShuffledRDD") { - // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD testCheckpointing(rdd => { + // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD new 
ShuffledRDD(rdd.map(x => (x % 2, 1)), partitioner) }) } test("UnionRDD") { def otherRDD = sc.makeRDD(1 to 10, 4) + + // Test whether the size of UnionRDDSplits reduce in size after parent RDD is checkpointed. + // Current implementation of UnionRDD has transient reference to parent RDDs, + // so only the splits will reduce in serialized size, not the RDD. testCheckpointing(_.union(otherRDD), false, true) testParentCheckpointing(_.union(otherRDD), false, true) } test("CartesianRDD") { - def otherRDD = sc.makeRDD(1 to 10, 4) - testCheckpointing(_.cartesian(otherRDD)) - testParentCheckpointing(_.cartesian(otherRDD), true, false) + def otherRDD = sc.makeRDD(1 to 10, 1) + testCheckpointing(new CartesianRDD(sc, _, otherRDD)) + + // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed + // Current implementation of CoalescedRDDSplit has transient reference to parent RDD, + // so only the RDD will reduce in serialized size, not the splits. + testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false) + + // Test that the CartesianRDD updates parent splits (CartesianRDD.s1/s2) after + // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits. + // Note that this test is very specific to the current implementation of CartesianRDD. + val ones = sc.makeRDD(1 to 100, 10).map(x => x) + ones.checkpoint // checkpoint that MappedRDD + val cartesian = new CartesianRDD(sc, ones, ones) + val splitBeforeCheckpoint = + serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit]) + cartesian.count() // do the checkpointing + val splitAfterCheckpoint = + serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit]) + assert( + (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) && + (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2), + "CartesianRDD.parents not updated after parent RDD checkpointed" + ) } test("CoalescedRDD") { @@ -94,7 +125,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed // Current implementation of CoalescedRDDSplit has transient reference to parent RDD, - // so does not serialize the RDD (not need to check its size). + // so only the RDD will reduce in serialized size, not the splits. 
testParentCheckpointing(new CoalescedRDD(_, 2), true, false) // Test that the CoalescedRDDSplit updates parent splits (CoalescedRDDSplit.parents) after @@ -145,13 +176,14 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val operatedRDD = op(baseRDD) val parentRDD = operatedRDD.dependencies.headOption.orNull val rddType = operatedRDD.getClass.getSimpleName + val numSplits = operatedRDD.splits.length // Find serialized sizes before and after the checkpoint val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) operatedRDD.checkpoint() val result = operatedRDD.collect() val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - + // Test whether the checkpoint file has been created assert(sc.objectFile[U](operatedRDD.getCheckpointFile.get).collect() === result) @@ -160,6 +192,9 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { // Test whether the splits have been changed to the new Hadoop splits assert(operatedRDD.splits.toList === operatedRDD.checkpointData.cpRDDSplits.toList) + + // Test whether the number of splits is same as before + assert(operatedRDD.splits.length === numSplits) // Test whether the data in the checkpointed RDD is same as original assert(operatedRDD.collect() === result) @@ -168,7 +203,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { // does not have any dependency to another RDD (e.g., ParallelCollection, // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing. if (testRDDSize) { - println("Size of " + rddType + + logInfo("Size of " + rddType + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") assert( rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, @@ -184,7 +219,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { // must be forgotten after checkpointing (to remove all reference to parent RDDs) and // replaced with the HadoopSplits of the checkpointed RDD. if (testRDDSplitSize) { - println("Size of " + rddType + " splits " + logInfo("Size of " + rddType + " splits " + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]") assert( splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, @@ -294,14 +329,118 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val bytes = Utils.serialize(obj) Utils.deserialize[T](bytes) } + /* + test("Consistency check for ResultTask") { + // Time -----------------------> + // Core 1: |<- count in thread 1, task 1 ->| |<-- checkpoint, task 1 ---->| |<- count in thread 2, task 2 ->| + // Core 2: |<- count in thread 1, task 2 ->| |<--- checkpoint, task 2 ---------->| |<- count in thread 2, task 1 ->| + // | + // checkpoint completed + sc.stop(); sc = null + System.clearProperty("spark.master.port") + + val dir = File.createTempFile("temp_", "") + dir.delete() + val ctxt = new SparkContext("local[2]", "ResultTask") + ctxt.setCheckpointDir(dir.toString) + + try { + val rdd = ctxt.makeRDD(1 to 2, 2).map(x => { + val state = CheckpointSuite.incrementState() + println("State = " + state) + if (state <= 3) { + // If executing the two tasks for the job comouting rdd.count + // of thread 1, or the first task for the recomputation due + // to checkpointing (saveing to HDFS), then do nothing + } else if (state == 4) { + // If executing the second task for the recomputation due to + // checkpointing. 
then prolong this task, to allow rdd.count + // of thread 2 to start before checkpoint of this RDD is completed + + Thread.sleep(1000) + println("State = " + state + " wake up") + } else { + // Else executing the tasks from thread 2 + Thread.sleep(1000) + println("State = " + state + " wake up") + } + + (x, 1) + }) + rdd.checkpoint() + val env = SparkEnv.get + + val thread1 = new Thread() { + override def run() { + try { + SparkEnv.set(env) + rdd.count() + } catch { + case e: Exception => CheckpointSuite.failed("Exception in thread 1", e) + } + } + } + thread1.start() + + val thread2 = new Thread() { + override def run() { + try { + SparkEnv.set(env) + CheckpointSuite.waitTillState(3) + println("\n\n\n\n") + rdd.count() + } catch { + case e: Exception => CheckpointSuite.failed("Exception in thread 2", e) + } + } + } + thread2.start() + + thread1.join() + thread2.join() + } finally { + dir.delete() + } + + assert(!CheckpointSuite.failed, CheckpointSuite.failureMessage) + + ctxt.stop() + + } + */ } object CheckpointSuite { + /* + var state = 0 + var failed = false + var failureMessage = "" + + def incrementState(): Int = { + this.synchronized { state += 1; this.notifyAll(); state } + } + + def getState(): Int = { + this.synchronized( state ) + } + + def waitTillState(s: Int) { + while(state < s) { + this.synchronized { this.wait() } + } + } + + def failed(msg: String, ex: Exception) { + failed = true + failureMessage += msg + "\n" + ex + "\n\n" + } + */ + // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { - println("First = " + first + ", second = " + second) + //println("First = " + first + ", second = " + second) new CoGroupedRDD[K]( Seq(first.asInstanceOf[RDD[(_, _)]], second.asInstanceOf[RDD[(_, _)]]), part -- cgit v1.2.3 From 2a87d816a24c62215d682e3a7af65489c0d6e708 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Dec 2012 01:44:43 -0800 Subject: Added clear property to JavaAPISuite to remove port binding errors. 
--- core/src/test/scala/spark/JavaAPISuite.java | 2 ++ 1 file changed, 2 insertions(+) (limited to 'core') diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 5875506179..6bd9836a93 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -44,6 +44,8 @@ public class JavaAPISuite implements Serializable { public void tearDown() { sc.stop(); sc = null; + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port"); } static class ReverseIntComparator implements Comparator, Serializable { -- cgit v1.2.3 From fa28f25619d6712e5f920f498ec03085ea208b4d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Dec 2012 13:59:43 -0800 Subject: Fixed bug in UnionRDD and CoGroupedRDD --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 9 +- core/src/main/scala/spark/rdd/UnionRDD.scala | 12 +-- core/src/test/scala/spark/CheckpointSuite.scala | 104 ----------------------- 3 files changed, 10 insertions(+), 115 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index e4e70b13ba..bc6d16ee8b 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -15,8 +15,11 @@ import spark.Split import java.io.{ObjectOutputStream, IOException} private[spark] sealed trait CoGroupSplitDep extends Serializable -private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], splitIndex: Int, var split: Split = null) - extends CoGroupSplitDep { +private[spark] case class NarrowCoGroupSplitDep( + rdd: RDD[_], + splitIndex: Int, + var split: Split + ) extends CoGroupSplitDep { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { @@ -75,7 +78,7 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) case s: ShuffleDependency[_, _] => new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep case _ => - new NarrowCoGroupSplitDep(r, i): CoGroupSplitDep + new NarrowCoGroupSplitDep(r, i, r.splits(i)): CoGroupSplitDep } }.toList) } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 808729f18d..a84867492b 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -5,14 +5,10 @@ import scala.collection.mutable.ArrayBuffer import spark._ import java.io.{ObjectOutputStream, IOException} -private[spark] class UnionSplit[T: ClassManifest]( - idx: Int, - rdd: RDD[T], - splitIndex: Int, - var split: Split = null) - extends Split - with Serializable { - +private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int) + extends Split { + var split: Split = rdd.splits(splitIndex) + def iterator() = rdd.iterator(split) def preferredLocations() = rdd.preferredLocations(split) override val index: Int = idx diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 7b323e089c..909c55c91c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -329,114 +329,10 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val bytes = Utils.serialize(obj) Utils.deserialize[T](bytes) } - /* - test("Consistency check for ResultTask") { - // Time -----------------------> - // Core 1: |<- count in 
thread 1, task 1 ->| |<-- checkpoint, task 1 ---->| |<- count in thread 2, task 2 ->| - // Core 2: |<- count in thread 1, task 2 ->| |<--- checkpoint, task 2 ---------->| |<- count in thread 2, task 1 ->| - // | - // checkpoint completed - sc.stop(); sc = null - System.clearProperty("spark.master.port") - - val dir = File.createTempFile("temp_", "") - dir.delete() - val ctxt = new SparkContext("local[2]", "ResultTask") - ctxt.setCheckpointDir(dir.toString) - - try { - val rdd = ctxt.makeRDD(1 to 2, 2).map(x => { - val state = CheckpointSuite.incrementState() - println("State = " + state) - if (state <= 3) { - // If executing the two tasks for the job comouting rdd.count - // of thread 1, or the first task for the recomputation due - // to checkpointing (saveing to HDFS), then do nothing - } else if (state == 4) { - // If executing the second task for the recomputation due to - // checkpointing. then prolong this task, to allow rdd.count - // of thread 2 to start before checkpoint of this RDD is completed - - Thread.sleep(1000) - println("State = " + state + " wake up") - } else { - // Else executing the tasks from thread 2 - Thread.sleep(1000) - println("State = " + state + " wake up") - } - - (x, 1) - }) - rdd.checkpoint() - val env = SparkEnv.get - - val thread1 = new Thread() { - override def run() { - try { - SparkEnv.set(env) - rdd.count() - } catch { - case e: Exception => CheckpointSuite.failed("Exception in thread 1", e) - } - } - } - thread1.start() - - val thread2 = new Thread() { - override def run() { - try { - SparkEnv.set(env) - CheckpointSuite.waitTillState(3) - println("\n\n\n\n") - rdd.count() - } catch { - case e: Exception => CheckpointSuite.failed("Exception in thread 2", e) - } - } - } - thread2.start() - - thread1.join() - thread2.join() - } finally { - dir.delete() - } - - assert(!CheckpointSuite.failed, CheckpointSuite.failureMessage) - - ctxt.stop() - - } - */ } object CheckpointSuite { - /* - var state = 0 - var failed = false - var failureMessage = "" - - def incrementState(): Int = { - this.synchronized { state += 1; this.notifyAll(); state } - } - - def getState(): Int = { - this.synchronized( state ) - } - - def waitTillState(s: Int) { - while(state < s) { - this.synchronized { this.wait() } - } - } - - def failed(msg: String, ex: Exception) { - failed = true - failureMessage += msg + "\n" + ex + "\n\n" - } - */ - // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { -- cgit v1.2.3 From 8e74fac215e8b9cda7e35111c5116e3669c6eb97 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Dec 2012 15:36:12 -0800 Subject: Made checkpoint data in RDDs optional to further reduce serialized size. 
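The size reduction claimed here comes from allocating the checkpoint bookkeeping lazily behind an Option, so the common, never-checkpointed case serializes a plain None. A minimal standalone sketch of that pattern, with invented names (CheckpointMeta, SmallRdd, OptionalMetaDemo) standing in for the real classes:

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    // Bookkeeping that only checkpointed RDDs need to carry around.
    class CheckpointMeta extends Serializable {
      var markedForCheckpoint = false
      var checkpointFile: Option[String] = None
    }

    class SmallRdd extends Serializable {
      // None by default, so RDDs that never call checkpoint() serialize almost nothing here.
      var checkpointData: Option[CheckpointMeta] = None

      def checkpoint() {
        if (checkpointData.isEmpty) checkpointData = Some(new CheckpointMeta)
        checkpointData.get.markedForCheckpoint = true
      }

      def isCheckpointed: Boolean = checkpointData.exists(_.checkpointFile.isDefined)
    }

    object OptionalMetaDemo extends App {
      def serializedSize(o: AnyRef): Int = {
        val bytes = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(bytes)
        out.writeObject(o)
        out.close()
        bytes.size
      }
      val plain = new SmallRdd
      val marked = new SmallRdd
      marked.checkpoint()
      println("plain: " + serializedSize(plain) + " bytes, marked: " + serializedSize(marked) + " bytes")
    }

Running the demo prints a smaller size for the plain instance, which is the effect this commit is after whenever an RDD object gets dragged into a serialized task.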
--- core/src/main/scala/spark/RDD.scala | 19 +++++++++++-------- core/src/main/scala/spark/SparkContext.scala | 11 +++++++++++ core/src/test/scala/spark/CheckpointSuite.scala | 12 ++++++------ .../src/main/scala/spark/streaming/DStream.scala | 4 +--- 4 files changed, 29 insertions(+), 17 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index efa03d5185..6c04769c82 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -112,7 +112,7 @@ abstract class RDD[T: ClassManifest]( // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE - protected[spark] val checkpointData = new RDDCheckpointData(this) + protected[spark] var checkpointData: Option[RDDCheckpointData[T]] = None /** Returns the first parent RDD */ protected[spark] def firstParent[U: ClassManifest] = { @@ -149,7 +149,7 @@ abstract class RDD[T: ClassManifest]( def getPreferredLocations(split: Split) = { if (isCheckpointed) { - checkpointData.preferredLocations(split) + checkpointData.get.preferredLocations(split) } else { preferredLocations(split) } @@ -163,7 +163,7 @@ abstract class RDD[T: ClassManifest]( final def iterator(split: Split): Iterator[T] = { if (isCheckpointed) { // ASSUMPTION: Checkpoint Hadoop RDD will have same number of splits as original - checkpointData.iterator(split) + checkpointData.get.iterator(split) } else if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) } else { @@ -516,21 +516,24 @@ abstract class RDD[T: ClassManifest]( * require recomputation. */ def checkpoint() { - checkpointData.markForCheckpoint() + if (checkpointData.isEmpty) { + checkpointData = Some(new RDDCheckpointData(this)) + checkpointData.get.markForCheckpoint() + } } /** * Return whether this RDD has been checkpointed or not */ def isCheckpointed(): Boolean = { - checkpointData.isCheckpointed() + if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false } /** * Gets the name of the file to which this RDD was checkpointed */ def getCheckpointFile(): Option[String] = { - checkpointData.getCheckpointFile() + if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None } /** @@ -539,12 +542,12 @@ abstract class RDD[T: ClassManifest]( * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. */ protected[spark] def doCheckpoint() { - checkpointData.doCheckpoint() + if (checkpointData.isDefined) checkpointData.get.doCheckpoint() dependencies.foreach(_.rdd.doCheckpoint()) } /** - * Changes the dependencies of this RDD from its original parents to the new [[spark.rdd.HadoopRDD]] + * Changes the dependencies of this RDD from its original parents to the new RDD * (`newRDD`) created from the checkpoint file. This method must ensure that all references * to the original parent RDDs must be removed to enable the parent RDDs to be garbage * collected. 
Subclasses of RDD may override this method for implementing their own changing diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 654b1c2eb7..71ed4ef058 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -366,6 +366,17 @@ class SparkContext( .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes)) } + + protected[spark] def checkpointFile[T: ClassManifest]( + path: String, + minSplits: Int = defaultMinSplits + ): RDD[T] = { + val rdd = objectFile[T](path, minSplits) + rdd.checkpointData = Some(new RDDCheckpointData(rdd)) + rdd.checkpointData.get.cpFile = Some(path) + rdd + } + /** Build the union of a list of RDDs. */ def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 909c55c91c..0bffedb8db 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -57,7 +57,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.splits.length === numSplits) - assert(parCollection.splits.toList === parCollection.checkpointData.cpRDDSplits.toList) + assert(parCollection.splits.toList === parCollection.checkpointData.get.cpRDDSplits.toList) assert(parCollection.collect() === result) } @@ -72,7 +72,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.splits.length === numSplits) - assert(blockRDD.splits.toList === blockRDD.checkpointData.cpRDDSplits.toList) + assert(blockRDD.splits.toList === blockRDD.checkpointData.get.cpRDDSplits.toList) assert(blockRDD.collect() === result) } @@ -84,7 +84,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { } test("UnionRDD") { - def otherRDD = sc.makeRDD(1 to 10, 4) + def otherRDD = sc.makeRDD(1 to 10, 1) // Test whether the size of UnionRDDSplits reduce in size after parent RDD is checkpointed. 
// Current implementation of UnionRDD has transient reference to parent RDDs, @@ -191,7 +191,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(operatedRDD.dependencies.head.rdd != parentRDD) // Test whether the splits have been changed to the new Hadoop splits - assert(operatedRDD.splits.toList === operatedRDD.checkpointData.cpRDDSplits.toList) + assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.cpRDDSplits.toList) // Test whether the number of splits is same as before assert(operatedRDD.splits.length === numSplits) @@ -289,8 +289,8 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { */ def generateLongLineageRDD(): RDD[Int] = { var rdd = sc.makeRDD(1 to 100, 4) - for (i <- 1 to 20) { - rdd = rdd.map(x => x) + for (i <- 1 to 50) { + rdd = rdd.map(x => x + 1) } rdd } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index d290c5927e..69fefa21a0 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -372,9 +372,7 @@ extends Serializable with Logging { checkpointData.foreach { case(time, data) => { logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'") - val rdd = ssc.sc.objectFile[T](data.toString) - // Set the checkpoint file name to identify this RDD as a checkpointed RDD by updateCheckpointData() - rdd.checkpointData.cpFile = Some(data.toString) + val rdd = ssc.sc.checkpointFile[T](data.toString) generatedRDDs += ((time, rdd)) } } -- cgit v1.2.3 From 72eed2b95edb3b0b213517c815e09c3886b11669 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 17 Dec 2012 18:52:43 -0800 Subject: Converted CheckpointState in RDDCheckpointData to use scala Enumeration. --- core/src/main/scala/spark/RDDCheckpointData.scala | 48 +++++++++++------------ 1 file changed, 22 insertions(+), 26 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index ff2ed4cdfc..7613b338e6 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -5,45 +5,41 @@ import rdd.CoalescedRDD import scheduler.{ResultTask, ShuffleMapTask} /** - * This class contains all the information of the regarding RDD checkpointing. + * Enumeration to manage state transitions of an RDD through checkpointing + * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] */ +private[spark] object CheckpointState extends Enumeration { + type CheckpointState = Value + val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value +} +/** + * This class contains all the information of the regarding RDD checkpointing. 
+ */ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) extends Logging with Serializable { - /** - * This class manages the state transition of an RDD through checkpointing - * [ Not checkpointed --> marked for checkpointing --> checkpointing in progress --> checkpointed ] - */ - class CheckpointState extends Serializable { - var state = 0 + import CheckpointState._ - def mark() { if (state == 0) state = 1 } - def start() { assert(state == 1); state = 2 } - def finish() { assert(state == 2); state = 3 } - - def isMarked() = { state == 1 } - def isInProgress = { state == 2 } - def isCheckpointed = { state == 3 } - } - - val cpState = new CheckpointState() + var cpState = Initialized @transient var cpFile: Option[String] = None @transient var cpRDD: Option[RDD[T]] = None @transient var cpRDDSplits: Seq[Split] = Nil // Mark the RDD for checkpointing - def markForCheckpoint() = { - RDDCheckpointData.synchronized { cpState.mark() } + def markForCheckpoint() { + RDDCheckpointData.synchronized { + if (cpState == Initialized) cpState = MarkedForCheckpoint + } } // Is the RDD already checkpointed - def isCheckpointed() = { - RDDCheckpointData.synchronized { cpState.isCheckpointed } + def isCheckpointed(): Boolean = { + RDDCheckpointData.synchronized { cpState == Checkpointed } } - // Get the file to which this RDD was checkpointed to as a Option - def getCheckpointFile() = { + // Get the file to which this RDD was checkpointed to as an Option + def getCheckpointFile(): Option[String] = { RDDCheckpointData.synchronized { cpFile } } @@ -52,8 +48,8 @@ extends Logging with Serializable { // If it is marked for checkpointing AND checkpointing is not already in progress, // then set it to be in progress, else return RDDCheckpointData.synchronized { - if (cpState.isMarked && !cpState.isInProgress) { - cpState.start() + if (cpState == MarkedForCheckpoint) { + cpState = CheckpointingInProgress } else { return } @@ -87,7 +83,7 @@ extends Logging with Serializable { cpRDD = Some(newRDD) cpRDDSplits = newRDD.splits rdd.changeDependencies(newRDD) - cpState.finish() + cpState = Checkpointed RDDCheckpointData.checkpointCompleted() logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) } -- cgit v1.2.3 From 5184141936c18f12c6738caae6fceee4d15800e2 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 18 Dec 2012 13:30:53 -0800 Subject: Introduced getSpits, getDependencies, and getPreferredLocations in RDD and RDDCheckpointData. 
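The refactoring named in this commit message follows a template-method shape: subclasses implement protected get* hooks describing the RDD as originally defined, and callers go through final accessors that switch to the checkpointed data once it exists. A stripped-down sketch, with invented names and Int stand-ins for real Split objects, assuming only the general pattern rather than the actual RDD code:

    // Subclasses fill in the protected hooks; callers always use the final accessors.
    abstract class SketchRDD {
      // Set by the (omitted) checkpointing machinery once the data has been saved.
      @volatile protected var checkpointedCopy: Option[SketchRDD] = None

      // Hooks for subclasses.
      protected def getSplits: Array[Int]
      protected def getPreferredLocations(split: Int): Seq[String] = Nil

      // Public API used by the scheduler: transparently answered from the checkpoint
      // copy when one exists.
      final def splits: Array[Int] =
        checkpointedCopy.map(_.splits).getOrElse(getSplits)

      final def preferredLocations(split: Int): Seq[String] =
        checkpointedCopy.map(_.preferredLocations(split)).getOrElse(getPreferredLocations(split))

      // Invoked after checkpointing has produced the replacement RDD.
      def markCheckpointed(copy: SketchRDD) { checkpointedCopy = Some(copy) }
    }

    // A concrete subclass only has to provide the protected hooks.
    class RangeSketchRDD(numSlices: Int) extends SketchRDD {
      protected def getSplits: Array[Int] = (0 until numSlices).toArray
    }

This mirrors the new final splits and preferredLocations accessors in the RDD.scala hunk below, where the real switch is the isCheckpointed test and the replacement data lives in RDDCheckpointData.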
--- core/src/main/scala/spark/PairRDDFunctions.scala | 4 +- core/src/main/scala/spark/ParallelCollection.scala | 9 +- core/src/main/scala/spark/RDD.scala | 123 +++++++++++++-------- core/src/main/scala/spark/RDDCheckpointData.scala | 10 +- core/src/main/scala/spark/rdd/BlockRDD.scala | 9 +- core/src/main/scala/spark/rdd/CartesianRDD.scala | 12 +- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 11 +- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 10 +- core/src/main/scala/spark/rdd/FilteredRDD.scala | 2 +- core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 2 +- core/src/main/scala/spark/rdd/GlommedRDD.scala | 2 +- core/src/main/scala/spark/rdd/HadoopRDD.scala | 4 +- .../main/scala/spark/rdd/MapPartitionsRDD.scala | 2 +- .../spark/rdd/MapPartitionsWithSplitRDD.scala | 2 +- core/src/main/scala/spark/rdd/MappedRDD.scala | 2 +- core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 4 +- core/src/main/scala/spark/rdd/PipedRDD.scala | 2 +- core/src/main/scala/spark/rdd/SampledRDD.scala | 9 +- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 7 +- core/src/main/scala/spark/rdd/UnionRDD.scala | 13 +-- .../main/scala/spark/scheduler/DAGScheduler.scala | 2 +- core/src/test/scala/spark/CheckpointSuite.scala | 6 +- 22 files changed, 134 insertions(+), 113 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 1f82bd3ab8..09ac606cfb 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -628,7 +628,7 @@ private[spark] class MappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => U) extends RDD[(K, U)](prev.get) { - override def splits = firstParent[(K, V)].splits + override def getSplits = firstParent[(K, V)].splits override val partitioner = firstParent[(K, V)].partitioner override def compute(split: Split) = firstParent[(K, V)].iterator(split).map{case (k, v) => (k, f(v))} } @@ -637,7 +637,7 @@ private[spark] class FlatMappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U]) extends RDD[(K, U)](prev.get) { - override def splits = firstParent[(K, V)].splits + override def getSplits = firstParent[(K, V)].splits override val partitioner = firstParent[(K, V)].partitioner override def compute(split: Split) = { firstParent[(K, V)].iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) } diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9d12af6912..0bc5b2ff11 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -37,15 +37,12 @@ private[spark] class ParallelCollection[T: ClassManifest]( slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray } - override def splits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_.asInstanceOf[Array[Split]] override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator - - override def preferredLocations(s: Split): Seq[String] = Nil - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 6c04769c82..f3e422fa5f 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -81,48 
+81,33 @@ abstract class RDD[T: ClassManifest]( def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) - // Methods that must be implemented by subclasses: - - /** Set of partitions in this RDD. */ - def splits: Array[Split] + // ======================================================================= + // Methods that should be implemented by subclasses of RDD + // ======================================================================= /** Function for computing a given partition. */ def compute(split: Split): Iterator[T] - /** How this RDD depends on any parent RDDs. */ - def dependencies: List[Dependency[_]] = dependencies_ + /** Set of partitions in this RDD. */ + protected def getSplits(): Array[Split] - /** Record user function generating this RDD. */ - private[spark] val origin = Utils.getSparkCallSite - - /** Optionally overridden by subclasses to specify how they are partitioned. */ - val partitioner: Option[Partitioner] = None + /** How this RDD depends on any parent RDDs. */ + protected def getDependencies(): List[Dependency[_]] = dependencies_ /** Optionally overridden by subclasses to specify placement preferences. */ - def preferredLocations(split: Split): Seq[String] = Nil - - /** The [[spark.SparkContext]] that this RDD was created on. */ - def context = sc + protected def getPreferredLocations(split: Split): Seq[String] = Nil - private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] - - /** A unique ID for this RDD (within its SparkContext). */ - val id = sc.newRddId() - - // Variables relating to persistence - private var storageLevel: StorageLevel = StorageLevel.NONE + /** Optionally overridden by subclasses to specify how they are partitioned. */ + val partitioner: Option[Partitioner] = None - protected[spark] var checkpointData: Option[RDDCheckpointData[T]] = None - /** Returns the first parent RDD */ - protected[spark] def firstParent[U: ClassManifest] = { - dependencies.head.rdd.asInstanceOf[RDD[U]] - } - /** Returns the `i` th parent RDD */ - protected[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] + // ======================================================================= + // Methods and fields available on all RDDs + // ======================================================================= - // Methods available on all RDDs: + /** A unique ID for this RDD (within its SparkContext). */ + val id = sc.newRddId() /** * Set this RDD's storage level to persist its values across operations after the first time @@ -147,11 +132,39 @@ abstract class RDD[T: ClassManifest]( /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - def getPreferredLocations(split: Split) = { + /** + * Get the preferred location of a split, taking into account whether the + * RDD is checkpointed or not. + */ + final def preferredLocations(split: Split): Seq[String] = { + if (isCheckpointed) { + checkpointData.get.getPreferredLocations(split) + } else { + getPreferredLocations(split) + } + } + + /** + * Get the array of splits of this RDD, taking into account whether the + * RDD is checkpointed or not. + */ + final def splits: Array[Split] = { + if (isCheckpointed) { + checkpointData.get.getSplits + } else { + getSplits + } + } + + /** + * Get the array of splits of this RDD, taking into account whether the + * RDD is checkpointed or not. 
+ */ + final def dependencies: List[Dependency[_]] = { if (isCheckpointed) { - checkpointData.get.preferredLocations(split) + dependencies_ } else { - preferredLocations(split) + getDependencies } } @@ -536,6 +549,27 @@ abstract class RDD[T: ClassManifest]( if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None } + // ======================================================================= + // Other internal methods and fields + // ======================================================================= + + private var storageLevel: StorageLevel = StorageLevel.NONE + + /** Record user function generating this RDD. */ + private[spark] val origin = Utils.getSparkCallSite + + private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] + + private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None + + /** Returns the first parent RDD */ + protected[spark] def firstParent[U: ClassManifest] = { + dependencies.head.rdd.asInstanceOf[RDD[U]] + } + + /** The [[spark.SparkContext]] that this RDD was created on. */ + def context = sc + /** * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler * after a job using this RDD has completed (therefore the RDD has been materialized and @@ -548,23 +582,18 @@ abstract class RDD[T: ClassManifest]( /** * Changes the dependencies of this RDD from its original parents to the new RDD - * (`newRDD`) created from the checkpoint file. This method must ensure that all references - * to the original parent RDDs must be removed to enable the parent RDDs to be garbage - * collected. Subclasses of RDD may override this method for implementing their own changing - * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + * (`newRDD`) created from the checkpoint file. */ protected[spark] def changeDependencies(newRDD: RDD[_]) { + clearDependencies() dependencies_ = List(new OneToOneDependency(newRDD)) } - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - oos.defaultWriteObject() - } - - @throws(classOf[IOException]) - private def readObject(ois: ObjectInputStream) { - ois.defaultReadObject() - } - + /** + * Clears the dependencies of this RDD. This method must ensure that all references + * to the original parent RDDs must be removed to enable the parent RDDs to be garbage + * collected. Subclasses of RDD may override this method for implementing their own changing + * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. 
+ */ + protected[spark] def clearDependencies() { } } diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index 7613b338e6..e4c0912cdc 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -24,7 +24,6 @@ extends Logging with Serializable { var cpState = Initialized @transient var cpFile: Option[String] = None @transient var cpRDD: Option[RDD[T]] = None - @transient var cpRDDSplits: Seq[Split] = Nil // Mark the RDD for checkpointing def markForCheckpoint() { @@ -81,7 +80,6 @@ extends Logging with Serializable { RDDCheckpointData.synchronized { cpFile = Some(file) cpRDD = Some(newRDD) - cpRDDSplits = newRDD.splits rdd.changeDependencies(newRDD) cpState = Checkpointed RDDCheckpointData.checkpointCompleted() @@ -90,12 +88,18 @@ extends Logging with Serializable { } // Get preferred location of a split after checkpointing - def preferredLocations(split: Split) = { + def getPreferredLocations(split: Split) = { RDDCheckpointData.synchronized { cpRDD.get.preferredLocations(split) } } + def getSplits: Array[Split] = { + RDDCheckpointData.synchronized { + cpRDD.get.splits + } + } + // Get iterator. This is called at the worker nodes. def iterator(split: Split): Iterator[T] = { rdd.firstParent[T].iterator(split) diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index 0c8cdd10dd..68e570eb15 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -29,7 +29,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St HashMap(blockIds.zip(locations):_*) } - override def splits = splits_ + override def getSplits = splits_ override def compute(split: Split): Iterator[T] = { val blockManager = SparkEnv.get.blockManager @@ -41,12 +41,11 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St } } - override def preferredLocations(split: Split) = + override def getPreferredLocations(split: Split) = locations_(split.asInstanceOf[BlockRDDSplit].blockId) - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 9975e79b08..116644bd52 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -45,9 +45,9 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( array } - override def splits = splits_ + override def getSplits = splits_ - override def preferredLocations(split: Split) = { + override def getPreferredLocations(split: Split) = { val currSplit = split.asInstanceOf[CartesianSplit] rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) } @@ -66,11 +66,11 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( } ) - override def dependencies = deps_ + override def getDependencies = deps_ - override def changeDependencies(newRDD: RDD[_]) { - deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + deps_ = Nil + splits_ = null rdd1 = null rdd2 = null } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 
bc6d16ee8b..9cc95dc172 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -65,9 +65,7 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) deps.toList } - // Pre-checkpoint dependencies deps_ should be transient (deps_) - // but post-checkpoint dependencies must not be transient (dependencies_) - override def dependencies = if (isCheckpointed) dependencies_ else deps_ + override def getDependencies = deps_ @transient var splits_ : Array[Split] = { @@ -85,7 +83,7 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) array } - override def splits = splits_ + override def getSplits = splits_ override val partitioner = Some(part) @@ -117,10 +115,9 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) map.iterator } - override def changeDependencies(newRDD: RDD[_]) { + override def clearDependencies() { deps_ = null - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + splits_ = null rdds = null } } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 088958942e..85d0fa9f6a 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -44,7 +44,7 @@ class CoalescedRDD[T: ClassManifest]( } } - override def splits = splits_ + override def getSplits = splits_ override def compute(split: Split): Iterator[T] = { split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { @@ -59,11 +59,11 @@ class CoalescedRDD[T: ClassManifest]( } ) - override def dependencies = deps_ + override def getDependencies() = deps_ - override def changeDependencies(newRDD: RDD[_]) { - deps_ = List(new OneToOneDependency(newRDD)) - splits_ = newRDD.splits + override def clearDependencies() { + deps_ = Nil + splits_ = null prev = null } } diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index 02f2e7c246..309ed2399d 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -9,6 +9,6 @@ class FilteredRDD[T: ClassManifest]( f: T => Boolean) extends RDD[T](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).filter(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index cdc8ecdcfe..1160e68bb8 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -9,6 +9,6 @@ class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( f: T => TraversableOnce[U]) extends RDD[U](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).flatMap(f) } diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index df6f61c69d..4fab1a56fa 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -6,6 +6,6 @@ import spark.Split private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def 
compute(split: Split) = Array(firstParent[T].iterator(split).toArray).iterator } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index af54f23ebc..fce190b860 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -67,7 +67,7 @@ class HadoopRDD[K, V]( .asInstanceOf[InputFormat[K, V]] } - override def splits = splits_ + override def getSplits = splits_ override def compute(theSplit: Split) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[HadoopSplit] @@ -110,7 +110,7 @@ class HadoopRDD[K, V]( } } - override def preferredLocations(split: Split) = { + override def getPreferredLocations(split: Split) = { // TODO: Filtering out "localhost" in case of file:// URLs val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index 23b9fb023b..5f4acee041 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -12,6 +12,6 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = f(firstParent[T].iterator(split)) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index 41955c1d7a..f0f3f2c7c7 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -14,6 +14,6 @@ class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( f: (Int, Iterator[T]) => Iterator[U]) extends RDD[U](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = f(split.index, firstParent[T].iterator(split)) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 6f8cb21fd3..44b542db93 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -9,6 +9,6 @@ class MappedRDD[U: ClassManifest, T: ClassManifest]( f: T => U) extends RDD[U](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).map(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index c12df5839e..91f89e3c75 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -55,7 +55,7 @@ class NewHadoopRDD[K, V]( result } - override def splits = splits_ + override def getSplits = splits_ override def compute(theSplit: Split) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopSplit] @@ -89,7 +89,7 @@ class NewHadoopRDD[K, V]( } } - override def preferredLocations(split: Split) = { + override def getPreferredLocations(split: Split) = { val theSplit = split.asInstanceOf[NewHadoopSplit] theSplit.serializableHadoopSplit.value.getLocations.filter(_ 
!= "localhost") } diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index d2047375ea..a88929e55e 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -29,7 +29,7 @@ class PipedRDD[T: ClassManifest]( // using a standard StringTokenizer (i.e. by spaces) def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split): Iterator[String] = { val pb = new ProcessBuilder(command) diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index c622e14a66..da6f65765c 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -26,9 +26,9 @@ class SampledRDD[T: ClassManifest]( firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } - override def splits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_.asInstanceOf[Array[Split]] - override def preferredLocations(split: Split) = + override def getPreferredLocations(split: Split) = firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) override def compute(splitIn: Split) = { @@ -51,8 +51,7 @@ class SampledRDD[T: ClassManifest]( } } - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index a9dd3f35ed..2caf33c21e 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -25,15 +25,14 @@ class ShuffledRDD[K, V]( @transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) - override def splits = splits_ + override def getSplits = splits_ override def compute(split: Split): Iterator[(K, V)] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) } - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index a84867492b..05ed6172d1 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -37,7 +37,7 @@ class UnionRDD[T: ClassManifest]( array } - override def splits = splits_ + override def getSplits = splits_ @transient var deps_ = { val deps = new ArrayBuffer[Dependency[_]] @@ -49,19 +49,16 @@ class UnionRDD[T: ClassManifest]( deps.toList } - // Pre-checkpoint dependencies deps_ should be transient (deps_) - // but post-checkpoint dependencies must not be transient (dependencies_) - override def dependencies = if (isCheckpointed) dependencies_ else deps_ + override def getDependencies = deps_ override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() - override def preferredLocations(s: Split): Seq[String] = + override def getPreferredLocations(s: Split): Seq[String] = s.asInstanceOf[UnionSplit[T]].preferredLocations() - override def 
changeDependencies(newRDD: RDD[_]) { + override def clearDependencies() { deps_ = null - dependencies_ = List(new OneToOneDependency(newRDD)) - splits_ = newRDD.splits + splits_ = null rdds = null } } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 33d35b35d1..4b2570fa2b 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -575,7 +575,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with return cached } // If the RDD has some placement preferences (as is the case for input RDDs), get those - val rddPrefs = rdd.getPreferredLocations(rdd.splits(partition)).toList + val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList if (rddPrefs != Nil) { return rddPrefs } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 0bffedb8db..19626d2450 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -57,7 +57,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.splits.length === numSplits) - assert(parCollection.splits.toList === parCollection.checkpointData.get.cpRDDSplits.toList) + assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList) assert(parCollection.collect() === result) } @@ -72,7 +72,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.splits.length === numSplits) - assert(blockRDD.splits.toList === blockRDD.checkpointData.get.cpRDDSplits.toList) + assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList) assert(blockRDD.collect() === result) } @@ -191,7 +191,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(operatedRDD.dependencies.head.rdd != parentRDD) // Test whether the splits have been changed to the new Hadoop splits - assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.cpRDDSplits.toList) + assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.getSplits.toList) // Test whether the number of splits is same as before assert(operatedRDD.splits.length === numSplits) -- cgit v1.2.3 From f9c5b0a6fe8d728e16c60c0cf51ced0054e3a387 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 20 Dec 2012 11:52:23 -0800 Subject: Changed checkpoint writing and reading process. 
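The new CheckpointRDD introduced below has each task write its partition to a temporary file and then rename it into place, so a failed or speculative attempt cannot leave a partial part file behind. A rough local-filesystem analogue of that write path (illustrative only; the real code in the diff uses the Hadoop FileSystem API and Spark's configured serializer, and its conflict handling differs slightly):

import java.io.{File, FileOutputStream, IOException, ObjectOutputStream}

object CheckpointWriteSketch {
  def writePartition(dir: File, splitId: Int, attemptId: Long, iter: Iterator[AnyRef]): Unit = {
    val finalFile = new File(dir, f"part-$splitId%05d")
    val tempFile  = new File(dir, s".${finalFile.getName}-attempt-$attemptId")
    if (tempFile.exists()) {
      throw new IOException(s"Checkpoint failed: temporary path $tempFile already exists")
    }
    // Serialize the whole partition to a hidden temporary file first...
    val out = new ObjectOutputStream(new FileOutputStream(tempFile))
    try iter.foreach(out.writeObject) finally out.close()
    // ...then commit by renaming; if a competing attempt already committed, keep its output.
    if (!tempFile.renameTo(finalFile) && !finalFile.exists()) {
      throw new IOException(s"Checkpoint failed: could not commit output for split $splitId")
    }
  }
}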
--- core/src/main/scala/spark/RDDCheckpointData.scala | 27 +---- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 117 ++++++++++++++++++++++ core/src/main/scala/spark/rdd/HadoopRDD.scala | 5 +- 3 files changed, 124 insertions(+), 25 deletions(-) create mode 100644 core/src/main/scala/spark/rdd/CheckpointRDD.scala (limited to 'core') diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index e4c0912cdc..1aa9b9aa1e 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -1,7 +1,7 @@ package spark import org.apache.hadoop.fs.Path -import rdd.CoalescedRDD +import rdd.{CheckpointRDD, CoalescedRDD} import scheduler.{ResultTask, ShuffleMapTask} /** @@ -55,30 +55,13 @@ extends Logging with Serializable { } // Save to file, and reload it as an RDD - val file = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString - rdd.saveAsObjectFile(file) - - val newRDD = { - val hadoopRDD = rdd.context.objectFile[T](file, rdd.splits.size) - - val oldSplits = rdd.splits.size - val newSplits = hadoopRDD.splits.size - - logDebug("RDD splits = " + oldSplits + " --> " + newSplits) - if (newSplits < oldSplits) { - throw new Exception("# splits after checkpointing is less than before " + - "[" + oldSplits + " --> " + newSplits) - } else if (newSplits > oldSplits) { - new CoalescedRDD(hadoopRDD, rdd.splits.size) - } else { - hadoopRDD - } - } - logDebug("New RDD has " + newRDD.splits.size + " splits") + val path = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString + rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _) + val newRDD = new CheckpointRDD[T](rdd.context, path) // Change the dependencies and splits of the RDD RDDCheckpointData.synchronized { - cpFile = Some(file) + cpFile = Some(path) cpRDD = Some(newRDD) rdd.changeDependencies(newRDD) cpState = Checkpointed diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala new file mode 100644 index 0000000000..c673ab6aaa --- /dev/null +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -0,0 +1,117 @@ +package spark.rdd + +import spark._ +import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.io.{NullWritable, BytesWritable} +import org.apache.hadoop.util.ReflectionUtils +import org.apache.hadoop.fs.Path +import java.io.{File, IOException, EOFException} +import java.text.NumberFormat + +private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split { + override val index: Int = idx +} + +class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) + extends RDD[T](sc, Nil) { + + @transient val path = new Path(checkpointPath) + @transient val fs = path.getFileSystem(new Configuration()) + + @transient val splits_ : Array[Split] = { + val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted + splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray + } + + override def getSplits = splits_ + + override def getPreferredLocations(split: Split): Seq[String] = { + val status = fs.getFileStatus(path) + val locations = fs.getFileBlockLocations(status, 0, status.getLen) + locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost") + } + + override def compute(split: Split): Iterator[T] = { + 
CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile) + } + + override def checkpoint() { + // Do nothing. Hadoop RDD should not be checkpointed. + } +} + +private[spark] object CheckpointRDD extends Logging { + + def splitIdToFileName(splitId: Int): String = { + val numfmt = NumberFormat.getInstance() + numfmt.setMinimumIntegerDigits(5) + numfmt.setGroupingUsed(false) + "part-" + numfmt.format(splitId) + } + + def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) { + val outputDir = new Path(path) + val fs = outputDir.getFileSystem(new Configuration()) + + val finalOutputName = splitIdToFileName(context.splitId) + val finalOutputPath = new Path(outputDir, finalOutputName) + val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId) + + if (fs.exists(tempOutputPath)) { + throw new IOException("Checkpoint failed: temporary path " + + tempOutputPath + " already exists") + } + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + + val fileOutputStream = if (blockSize < 0) { + fs.create(tempOutputPath, false, bufferSize) + } else { + // This is mainly for testing purpose + fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) + } + val serializer = SparkEnv.get.serializer.newInstance() + val serializeStream = serializer.serializeStream(fileOutputStream) + serializeStream.writeAll(iterator) + fileOutputStream.close() + + if (!fs.rename(tempOutputPath, finalOutputPath)) { + if (!fs.delete(finalOutputPath, true)) { + throw new IOException("Checkpoint failed: failed to delete earlier output of task " + + context.attemptId); + } + if (!fs.rename(tempOutputPath, finalOutputPath)) { + throw new IOException("Checkpoint failed: failed to save output of task: " + + context.attemptId) + } + } + } + + def readFromFile[T](path: String): Iterator[T] = { + val inputPath = new Path(path) + val fs = inputPath.getFileSystem(new Configuration()) + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val fileInputStream = fs.open(inputPath, bufferSize) + val serializer = SparkEnv.get.serializer.newInstance() + val deserializeStream = serializer.deserializeStream(fileInputStream) + deserializeStream.asIterator.asInstanceOf[Iterator[T]] + } + + // Test whether CheckpointRDD generate expected number of splits despite + // each split file having multiple blocks. This needs to be run on a + // cluster (mesos or standalone) using HDFS. + def main(args: Array[String]) { + import spark._ + + val Array(cluster, hdfsPath) = args + val sc = new SparkContext(cluster, "CheckpointRDD Test") + val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) + val path = new Path(hdfsPath, "temp") + val fs = path.getFileSystem(new Configuration()) + sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 10) _) + val cpRDD = new CheckpointRDD[Int](sc, path.toString) + assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same") + assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same") + fs.delete(path) + } +} diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index fce190b860..eca51758e4 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -25,8 +25,7 @@ import spark.Split * A Spark split class that wraps around a Hadoop InputSplit. 
*/ private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit) - extends Split - with Serializable { + extends Split { val inputSplit = new SerializableWritable[InputSplit](s) @@ -117,6 +116,6 @@ class HadoopRDD[K, V]( } override def checkpoint() { - // Do nothing. Hadoop RDD cannot be checkpointed. + // Do nothing. Hadoop RDD should not be checkpointed. } } -- cgit v1.2.3 From fe777eb77dee3c5bc5a7a332098d27f517ad3fe4 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 20 Dec 2012 13:39:27 -0800 Subject: Fixed bugs in CheckpointRDD and spark.CheckpointSuite. --- core/src/main/scala/spark/SparkContext.scala | 12 +++--------- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 3 +++ core/src/test/scala/spark/CheckpointSuite.scala | 6 +++--- 3 files changed, 9 insertions(+), 12 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 71ed4ef058..362aa04e66 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -37,9 +37,7 @@ import spark.broadcast._ import spark.deploy.LocalSparkCluster import spark.partial.ApproximateEvaluator import spark.partial.PartialResult -import spark.rdd.HadoopRDD -import spark.rdd.NewHadoopRDD -import spark.rdd.UnionRDD +import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD} import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} @@ -368,13 +366,9 @@ class SparkContext( protected[spark] def checkpointFile[T: ClassManifest]( - path: String, - minSplits: Int = defaultMinSplits + path: String ): RDD[T] = { - val rdd = objectFile[T](path, minSplits) - rdd.checkpointData = Some(new RDDCheckpointData(rdd)) - rdd.checkpointData.get.cpFile = Some(path) - rdd + new CheckpointRDD[T](this, path) } /** Build the union of a list of RDDs. 
*/ diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index c673ab6aaa..fbf8a9ef83 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -24,6 +24,9 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray } + checkpointData = Some(new RDDCheckpointData[T](this)) + checkpointData.get.cpFile = Some(checkpointPath) + override def getSplits = splits_ override def getPreferredLocations(split: Split): Seq[String] = { diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 19626d2450..6bc667bd4c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -54,7 +54,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { parCollection.checkpoint() assert(parCollection.dependencies === Nil) val result = parCollection.collect() - assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) + assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.splits.length === numSplits) assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList) @@ -69,7 +69,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val numSplits = blockRDD.splits.size blockRDD.checkpoint() val result = blockRDD.collect() - assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) + assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.splits.length === numSplits) assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList) @@ -185,7 +185,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) // Test whether the checkpoint file has been created - assert(sc.objectFile[U](operatedRDD.getCheckpointFile.get).collect() === result) + assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result) // Test whether dependencies have been changed from its earlier parent RDD assert(operatedRDD.dependencies.head.rdd != parentRDD) -- cgit v1.2.3 From 0bc0a60d3001dd231e13057a838d4b6550e5a2b9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 27 Dec 2012 15:37:33 -0800 Subject: Modifications to make sure LocalScheduler terminate cleanly without errors when SparkContext is shutdown, to minimize spurious exception during master failure tests. 
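The key guard in the diff below is Thread.currentThread().isInterrupted: once the local scheduler's thread pool has been shut down, a task that is still finishing must not call back into the already-stopped DAGScheduler listener. A small, self-contained illustration of that pattern (not the actual scheduler code):

import java.util.concurrent.{Executors, TimeUnit}

object InterruptAwareTask {
  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(1)
    pool.execute(new Runnable {
      override def run(): Unit = {
        try {
          Thread.sleep(50)  // simulate task work; shutdownNow() interrupts this sleep
        } catch {
          case _: InterruptedException => Thread.currentThread().interrupt()  // restore the flag
        }
        // Report back only if this task was not interrupted by a shutdown.
        if (!Thread.currentThread().isInterrupted) {
          println("task finished, notifying listener")
        }
      }
    })
    pool.shutdownNow()                        // interrupts the running task
    pool.awaitTermination(1, TimeUnit.SECONDS)
  }
}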
--- core/src/main/scala/spark/SparkContext.scala | 22 ++++++++++++---------- .../spark/scheduler/local/LocalScheduler.scala | 8 ++++++-- core/src/test/resources/log4j.properties | 2 +- .../src/test/scala/spark/ClosureCleanerSuite.scala | 2 ++ streaming/src/test/resources/log4j.properties | 13 ++++++++----- 5 files changed, 29 insertions(+), 18 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index caa9a1794b..0c8b0078a3 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -488,17 +488,19 @@ class SparkContext( if (dagScheduler != null) { dagScheduler.stop() dagScheduler = null + taskScheduler = null + // TODO: Cache.stop()? + env.stop() + // Clean up locally linked files + clearFiles() + clearJars() + SparkEnv.set(null) + ShuffleMapTask.clearCache() + ResultTask.clearCache() + logInfo("Successfully stopped SparkContext") + } else { + logInfo("SparkContext already stopped") } - taskScheduler = null - // TODO: Cache.stop()? - env.stop() - // Clean up locally linked files - clearFiles() - clearJars() - SparkEnv.set(null) - ShuffleMapTask.clearCache() - ResultTask.clearCache() - logInfo("Successfully stopped SparkContext") } /** diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index eb20fe41b2..17a0a4b103 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -81,7 +81,10 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( ser.serialize(Accumulators.values)) logInfo("Finished task " + idInJob) - listener.taskEnded(task, Success, resultToReturn, accumUpdates) + + // If the threadpool has not already been shutdown, notify DAGScheduler + if (!Thread.currentThread().isInterrupted) + listener.taskEnded(task, Success, resultToReturn, accumUpdates) } catch { case t: Throwable => { logError("Exception in task " + idInJob, t) @@ -91,7 +94,8 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon submitTask(task, idInJob) } else { // TODO: Do something nicer here to return all the way to the user - listener.taskEnded(task, new ExceptionFailure(t), null, null) + if (!Thread.currentThread().isInterrupted) + listener.taskEnded(task, new ExceptionFailure(t), null, null) } } } diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index 4c99e450bc..5ed388e91b 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -1,4 +1,4 @@ -# Set everything to be logged to the console +# Set everything to be logged to the file spark-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false diff --git a/core/src/test/scala/spark/ClosureCleanerSuite.scala b/core/src/test/scala/spark/ClosureCleanerSuite.scala index 7c0334d957..dfa2de80e6 100644 --- a/core/src/test/scala/spark/ClosureCleanerSuite.scala +++ b/core/src/test/scala/spark/ClosureCleanerSuite.scala @@ -47,6 +47,8 @@ object TestObject { val nums = sc.parallelize(Array(1, 2, 3, 4)) val answer = nums.map(_ + x).reduce(_ + _) sc.stop() + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + 
System.clearProperty("spark.master.port") return answer } } diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 02fe16866e..33bafebaab 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -1,8 +1,11 @@ -# Set everything to be logged to the console -log4j.rootCategory=WARN, console -log4j.appender.console=org.apache.log4j.ConsoleAppender -log4j.appender.console.layout=org.apache.log4j.PatternLayout -log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +# Set everything to be logged to the file streaming-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=streaming-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN + -- cgit v1.2.3 From 9e644402c155b5fc68794a17c36ddd19d3242f4f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 29 Dec 2012 18:31:51 -0800 Subject: Improved jekyll and scala docs. Made many classes and method private to remove them from scala docs. --- core/src/main/scala/spark/RDD.scala | 1 - docs/_plugins/copy_api_dirs.rb | 4 +- docs/streaming-programming-guide.md | 56 ++--- .../main/scala/spark/streaming/Checkpoint.scala | 5 +- .../src/main/scala/spark/streaming/DStream.scala | 249 +++++++++++++-------- .../scala/spark/streaming/FlumeInputDStream.scala | 2 +- .../src/main/scala/spark/streaming/Interval.scala | 1 + streaming/src/main/scala/spark/streaming/Job.scala | 2 + .../main/scala/spark/streaming/JobManager.scala | 1 + .../spark/streaming/NetworkInputDStream.scala | 8 +- .../spark/streaming/NetworkInputTracker.scala | 8 +- .../spark/streaming/PairDStreamFunctions.scala | 4 +- .../src/main/scala/spark/streaming/Scheduler.scala | 7 +- .../scala/spark/streaming/StreamingContext.scala | 43 ++-- .../scala/spark/streaming/examples/GrepRaw.scala | 2 +- .../streaming/examples/TopKWordCountRaw.scala | 2 +- .../spark/streaming/examples/WordCountRaw.scala | 2 +- .../examples/clickstream/PageViewStream.scala | 2 +- .../test/scala/spark/streaming/TestSuiteBase.scala | 2 +- 19 files changed, 233 insertions(+), 168 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 59e50a0b6b..1574533430 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -101,7 +101,6 @@ abstract class RDD[T: ClassManifest]( val partitioner: Option[Partitioner] = None - // ======================================================================= // Methods and fields available on all RDDs // ======================================================================= diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index e61c105449..7654511eeb 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -2,7 +2,7 @@ require 'fileutils' include FileUtils if ENV['SKIP_SCALADOC'] != '1' - projects = ["core", "examples", "repl", "bagel"] + projects = ["core", "examples", "repl", "bagel", "streaming"] puts "Moving to project root and building scaladoc." curr_dir = pwd @@ -11,7 +11,7 @@ if ENV['SKIP_SCALADOC'] != '1' puts "Running sbt/sbt doc from " + pwd + "; this may take a few minutes..." 
puts `sbt/sbt doc` - puts "moving back into docs dir." + puts "Moving back into docs dir." cd("docs") # Copy over the scaladoc from each project into the docs directory. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 90916545bc..7c421ac70f 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2,33 +2,44 @@ layout: global title: Streaming (Alpha) Programming Guide --- + +{:toc} + +# Overview +A Spark Streaming application is very similar to a Spark application; it consists of a *driver program* that runs the user's `main` function and continuously executes various *parallel operations* on input streams of data. The main abstraction Spark Streaming provides is a *discretized stream* (DStream), which is a continuous sequence of RDDs (distributed collection of elements) representing a continuous stream of data. DStreams can be created from live incoming data (such as data from a socket, Kafka, etc.) or they can be generated by transformation of existing DStreams using parallel operators like map, reduce, and window. The basic processing model is as follows: +(i) While a Spark Streaming driver program is running, the system receives data from various sources and divides the data into batches. Each batch of data is treated as an RDD, that is, an immutable and parallel collection of data. These input data RDDs are automatically persisted in memory (serialized by default) and replicated to two nodes for fault-tolerance. This sequence of RDDs is collectively referred to as an InputDStream. +(ii) Data received by InputDStreams are processed using DStream operations. Since all data is represented as RDDs and all DStream operations as RDD operations, data is automatically recovered in the event of node failures. + +This guide shows how to start programming with DStreams. + # Initializing Spark Streaming The first thing a Spark Streaming program must do is create a `StreamingContext` object, which tells Spark how to access a cluster. A `StreamingContext` can be created from an existing `SparkContext`, or directly: {% highlight scala %} -new StreamingContext(master, jobName, [sparkHome], [jars]) -new StreamingContext(sparkContext) -{% endhighlight %} - -Once a context is instantiated, the batch interval must be set: +import spark.SparkContext +import SparkContext._ -{% highlight scala %} -context.setBatchDuration(Milliseconds(2000)) +new StreamingContext(master, frameworkName, batchDuration) +new StreamingContext(sparkContext, batchDuration) {% endhighlight %} +The `master` parameter is either the [Mesos master URL](running-on-mesos.html) (for running on a cluster) or the special "local" string (for local mode) that is used to create a Spark Context. For more information about this, please refer to the [Spark programming guide](scala-programming-guide.html). -# DStreams - Discretized Streams -The primary abstraction in Spark Streaming is a DStream. A DStream represents distributed collection which is computed periodically according to a specified batch interval. DStream's can be chained together to create complex chains of transformation on streaming data. DStreams can be created by operating on existing DStreams or from an input source. 
To creating DStreams from an input source, use the StreamingContext: + +# Creating Input Sources - InputDStreams +The StreamingContext is used to creating InputDStreams from input sources: {% highlight scala %} -context.neworkStream(host, port) // A stream that reads from a socket -context.flumeStream(hosts, ports) // A stream populated by a Flume flow +context.neworkStream(host, port) // Creates a stream that uses a TCP socket to read data from : +context.flumeStream(host, port) // Creates a stream populated by a Flume flow {% endhighlight %} -# DStream Operators +A complete list of input sources is available in the [DStream API doc](api/streaming/index.html#spark.streaming.StreamingContext). + +## DStream Operations Once an input stream has been created, you can transform it using _stream operators_. Most of these operators return new DStreams which you can further transform. Eventually, you'll need to call an _output operator_, which forces evaluation of the stream by writing data out to an external source. -## Transformations +### Transformations DStreams support many of the transformations available on normal Spark RDD's: @@ -73,20 +84,13 @@ DStreams support many of the transformations available on normal Spark RDD's: cogroup(otherStream, [numTasks]) When called on streams of type (K, V) and (K, W), returns a stream of (K, Seq[V], Seq[W]) tuples. This operation is also called groupWith. - - -DStreams also support the following additional transformations: - -
reduce(func) Create a new single-element stream by aggregating the elements of the stream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel.
- -## Windowed Transformations -Spark streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a windowTime, which represents the width of the window and a slideTime, which represents the frequency during which the window is calculated. +Spark Streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a windowTime, which represents the width of the window and a slideTime, which represents the frequency during which the window is calculated. @@ -128,7 +132,7 @@ Spark streaming features windowed computations, which allow you to report statis
TransformationMeaning
-## Output Operators +### Output Operators When an output operator is called, it triggers the computation of a stream. Currently the following output operators are defined: @@ -140,22 +144,22 @@ When an output operator is called, it triggers the computation of a stream. Curr - + - + - + - + diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 770f7b0cc0..11a7232d7b 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -8,6 +8,7 @@ import org.apache.hadoop.conf.Configuration import java.io._ +private[streaming] class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Logging with Serializable { val master = ssc.sc.master @@ -30,6 +31,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) /** * Convenience class to speed up the writing of graph checkpoint to file */ +private[streaming] class CheckpointWriter(checkpointDir: String) extends Logging { val file = new Path(checkpointDir, "graph") val conf = new Configuration() @@ -65,7 +67,7 @@ class CheckpointWriter(checkpointDir: String) extends Logging { } - +private[streaming] object CheckpointReader extends Logging { def read(path: String): Checkpoint = { @@ -103,6 +105,7 @@ object CheckpointReader extends Logging { } } +private[streaming] class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoader) extends ObjectInputStream(inputStream_) { override def resolveClass(desc: ObjectStreamClass): Class[_] = { try { diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index d5048aeed7..3834b57ed3 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.conf.Configuration * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous * sequence of RDDs (of the same type) representing a continuous stream of data (see [[spark.RDD]] * for more details on RDDs). DStreams can either be created from live data (such as, data from - * HDFS. Kafka or Flume) or it can be generated by transformation existing DStreams using operations + * HDFS, Kafka or Flume) or it can be generated by transformation existing DStreams using operations * such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each * DStream periodically generates a RDD, either from live data or by transforming the RDD generated * by a parent DStream. 
@@ -38,33 +38,28 @@ import org.apache.hadoop.conf.Configuration * - A function that is used to generate an RDD after each time interval */ -case class DStreamCheckpointData(rdds: HashMap[Time, Any]) - -abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) -extends Serializable with Logging { +abstract class DStream[T: ClassManifest] ( + @transient protected[streaming] var ssc: StreamingContext + ) extends Serializable with Logging { initLogging() - /** - * ---------------------------------------------- - * Methods that must be implemented by subclasses - * ---------------------------------------------- - */ + // ======================================================================= + // Methods that should be implemented by subclasses of DStream + // ======================================================================= - // Time interval at which the DStream generates an RDD + /** Time interval after which the DStream generates a RDD */ def slideTime: Time - // List of parent DStreams on which this DStream depends on + /** List of parent DStreams on which this DStream depends on */ def dependencies: List[DStream[_]] - // Key method that computes RDD for a valid time + /** Method that generates a RDD for the given time */ def compute (validTime: Time): Option[RDD[T]] - /** - * --------------------------------------- - * Other general fields and methods of DStream - * --------------------------------------- - */ + // ======================================================================= + // Methods and fields available on all DStreams + // ======================================================================= // RDDs generated, marked as protected[streaming] so that testsuites can access it @transient @@ -87,12 +82,15 @@ extends Serializable with Logging { // Reference to whole DStream graph protected[streaming] var graph: DStreamGraph = null - def isInitialized = (zeroTime != null) + protected[streaming] def isInitialized = (zeroTime != null) // Duration for which the DStream requires its parent DStream to remember each RDD created - def parentRememberDuration = rememberDuration + protected[streaming] def parentRememberDuration = rememberDuration + + /** Returns the StreamingContext associated with this DStream */ + def context() = ssc - // Set caching level for the RDDs created by this DStream + /** Persists the RDDs of this DStream with the given storage level */ def persist(level: StorageLevel): DStream[T] = { if (this.isInitialized) { throw new UnsupportedOperationException( @@ -102,11 +100,16 @@ extends Serializable with Logging { this } + /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY_SER) - - // Turn on the default caching level for this RDD + + /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def cache(): DStream[T] = persist() + /** + * Enable periodic checkpointing of RDDs of this DStream + * @param interval Time interval after which generated RDD will be checkpointed + */ def checkpoint(interval: Time): DStream[T] = { if (isInitialized) { throw new UnsupportedOperationException( @@ -285,7 +288,7 @@ extends Serializable with Logging { * Generates a SparkStreaming job for the given time. This is an internal method that * should not be called directly. This default implementation creates a job * that materializes the corresponding RDD. Subclasses of DStream may override this - * (eg. PerRDDForEachDStream). + * (eg. 
ForEachDStream). */ protected[streaming] def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { @@ -420,65 +423,96 @@ extends Serializable with Logging { generatedRDDs = new HashMap[Time, RDD[T]] () } - /** - * -------------- - * DStream operations - * -------------- - */ + // ======================================================================= + // DStream operations + // ======================================================================= + + /** Returns a new DStream by applying a function to all elements of this DStream. */ def map[U: ClassManifest](mapFunc: T => U): DStream[U] = { new MappedDStream(this, ssc.sc.clean(mapFunc)) } + /** + * Returns a new DStream by applying a function to all elements of this DStream, + * and then flattening the results + */ def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = { new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc)) } + /** Returns a new DStream containing only the elements that satisfy a predicate. */ def filter(filterFunc: T => Boolean): DStream[T] = new FilteredDStream(this, filterFunc) + /** + * Return a new DStream in which each RDD is generated by applying glom() to each RDD of + * this DStream. Applying glom() to an RDD coalesces all elements within each partition into + * an array. + */ def glom(): DStream[Array[T]] = new GlommedDStream(this) - def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]): DStream[U] = { - new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc)) + /** + * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs + * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition + * of the RDD. + */ + def mapPartitions[U: ClassManifest]( + mapPartFunc: Iterator[T] => Iterator[U], + preservePartitioning: Boolean = false + ): DStream[U] = { + new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc), preservePartitioning) } - def reduce(reduceFunc: (T, T) => T): DStream[T] = this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) + /** + * Returns a new DStream in which each RDD has a single element generated by reducing each RDD + * of this DStream. + */ + def reduce(reduceFunc: (T, T) => T): DStream[T] = + this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) + /** + * Returns a new DStream in which each RDD has a single element generated by counting each RDD + * of this DStream. + */ def count(): DStream[Int] = this.map(_ => 1).reduce(_ + _) - - def collect(): DStream[Seq[T]] = this.map(x => (null, x)).groupByKey(1).map(_._2) - - def foreach(foreachFunc: T => Unit) { - val newStream = new PerElementForEachDStream(this, ssc.sc.clean(foreachFunc)) - ssc.registerOutputStream(newStream) - newStream - } - def foreachRDD(foreachFunc: RDD[T] => Unit) { - foreachRDD((r: RDD[T], t: Time) => foreachFunc(r)) + /** + * Applies a function to each RDD in this DStream. This is an output operator, so + * this DStream will be registered as an output stream and therefore materialized. + */ + def foreach(foreachFunc: RDD[T] => Unit) { + foreach((r: RDD[T], t: Time) => foreachFunc(r)) } - def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) { - val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc)) + /** + * Applies a function to each RDD in this DStream. This is an output operator, so + * this DStream will be registered as an output stream and therefore materialized. 
+ */ + def foreach(foreachFunc: (RDD[T], Time) => Unit) { + val newStream = new ForEachDStream(this, ssc.sc.clean(foreachFunc)) ssc.registerOutputStream(newStream) newStream } - def transformRDD[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = { - transformRDD((r: RDD[T], t: Time) => transformFunc(r)) + /** + * Returns a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. + */ + def transform[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = { + transform((r: RDD[T], t: Time) => transformFunc(r)) } - def transformRDD[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { + /** + * Returns a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. + */ + def transform[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { new TransformedDStream(this, ssc.sc.clean(transformFunc)) } - def toBlockingQueue() = { - val queue = new ArrayBlockingQueue[RDD[T]](10000) - this.foreachRDD(rdd => { - queue.add(rdd) - }) - queue - } - + /** + * Prints the first ten elements of each RDD generated in this DStream. This is an output + * operator, so this DStream will be registered as an output stream and there materialized. + */ def print() { def foreachFunc = (rdd: RDD[T], time: Time) => { val first11 = rdd.take(11) @@ -489,18 +523,42 @@ extends Serializable with Logging { if (first11.size > 10) println("...") println() } - val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc)) + val newStream = new ForEachDStream(this, ssc.sc.clean(foreachFunc)) ssc.registerOutputStream(newStream) } + /** + * Return a new DStream which is computed based on windowed batches of this DStream. + * The new DStream generates RDDs with the same interval as this DStream. + * @param windowTime width of the window; must be a multiple of this DStream's interval. + * @return + */ def window(windowTime: Time): DStream[T] = window(windowTime, this.slideTime) + /** + * Return a new DStream which is computed based on windowed batches of this DStream. + * @param windowTime duration (i.e., width) of the window; + * must be a multiple of this DStream's interval + * @param slideTime sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's interval + */ def window(windowTime: Time, slideTime: Time): DStream[T] = { new WindowedDStream(this, windowTime, slideTime) } + /** + * Returns a new DStream which computed based on tumbling window on this DStream. + * This is equivalent to window(batchTime, batchTime). + * @param batchTime tumbling window duration; must be a multiple of this DStream's interval + */ def tumble(batchTime: Time): DStream[T] = window(batchTime, batchTime) + /** + * Returns a new DStream in which each RDD has a single element generated by reducing all + * elements in a window over this DStream. windowTime and slideTime are as defined in the + * window() operation. This is equivalent to window(windowTime, slideTime).reduce(reduceFunc) + */ def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Time, slideTime: Time): DStream[T] = { this.window(windowTime, slideTime).reduce(reduceFunc) } @@ -516,17 +574,31 @@ extends Serializable with Logging { .map(_._2) } + /** + * Returns a new DStream in which each RDD has a single element generated by counting the number + * of elements in a window over this DStream. 
windowTime and slideTime are as defined in the + * window() operation. This is equivalent to window(windowTime, slideTime).count() + */ def countByWindow(windowTime: Time, slideTime: Time): DStream[Int] = { this.map(_ => 1).reduceByWindow(_ + _, _ - _, windowTime, slideTime) } + /** + * Returns a new DStream by unifying data of another DStream with this DStream. + * @param that Another DStream having the same interval (i.e., slideTime) as this DStream. + */ def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that)) - def slice(interval: Interval): Seq[RDD[T]] = { + /** + * Returns all the RDDs defined by the Interval object (both end times included) + */ + protected[streaming] def slice(interval: Interval): Seq[RDD[T]] = { slice(interval.beginTime, interval.endTime) } - // Get all the RDDs between fromTime to toTime (both included) + /** + * Returns all the RDDs between 'fromTime' to 'toTime' (both included) + */ def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { val rdds = new ArrayBuffer[RDD[T]]() var time = toTime.floor(slideTime) @@ -540,20 +612,26 @@ extends Serializable with Logging { rdds.toSeq } + /** + * Saves each RDD in this DStream as a Sequence file of serialized objects. + */ def saveAsObjectFiles(prefix: String, suffix: String = "") { val saveFunc = (rdd: RDD[T], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsObjectFile(file) } - this.foreachRDD(saveFunc) + this.foreach(saveFunc) } + /** + * Saves each RDD in this DStream as at text file, using string representation of elements. + */ def saveAsTextFiles(prefix: String, suffix: String = "") { val saveFunc = (rdd: RDD[T], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsTextFile(file) } - this.foreachRDD(saveFunc) + this.foreach(saveFunc) } def register() { @@ -561,6 +639,8 @@ extends Serializable with Logging { } } +private[streaming] +case class DStreamCheckpointData(rdds: HashMap[Time, Any]) abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) extends DStream[T](ssc_) { @@ -583,6 +663,7 @@ abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContex * TODO */ +private[streaming] class MappedDStream[T: ClassManifest, U: ClassManifest] ( parent: DStream[T], mapFunc: T => U @@ -602,6 +683,7 @@ class MappedDStream[T: ClassManifest, U: ClassManifest] ( * TODO */ +private[streaming] class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( parent: DStream[T], flatMapFunc: T => Traversable[U] @@ -621,6 +703,7 @@ class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( * TODO */ +private[streaming] class FilteredDStream[T: ClassManifest]( parent: DStream[T], filterFunc: T => Boolean @@ -640,9 +723,11 @@ class FilteredDStream[T: ClassManifest]( * TODO */ +private[streaming] class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( parent: DStream[T], - mapPartFunc: Iterator[T] => Iterator[U] + mapPartFunc: Iterator[T] => Iterator[U], + preservePartitioning: Boolean ) extends DStream[U](parent.ssc) { override def dependencies = List(parent) @@ -650,7 +735,7 @@ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( override def slideTime: Time = parent.slideTime override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc)) + parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc, preservePartitioning)) } } @@ -659,6 +744,7 @@ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( * TODO */ 
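The hunks above document the windowing and output side of the DStream API (window, tumble, reduceByWindow, countByWindow, foreach, saveAsObjectFiles, saveAsTextFiles). Below is a short sketch of how those output operators are typically wired up, under the assumption that some DStream[String] named pageViews has been built elsewhere; nothing in it comes verbatim from the patch except the method names and signatures.

    import spark.RDD
    import spark.streaming.{DStream, Seconds, Time}

    object OutputOpsSketch {
      // "pageViews" is an assumed upstream DStream[String]; this only attaches outputs.
      def attachOutputs(pageViews: DStream[String]) {
        // Number of elements seen in a 30-second window, recomputed every 5 seconds.
        val windowCounts = pageViews.countByWindow(Seconds(30), Seconds(5))

        // foreach registers an output stream; the (RDD, Time) variant also exposes
        // the batch time of each generated RDD.
        windowCounts.foreach((rdd: RDD[Int], time: Time) =>
          println("Window ending at " + time + ": " + rdd.collect().mkString))

        // One set of text files per batch, named from the prefix and the batch time.
        pageViews.saveAsTextFiles("page-views")
      }
    }

Both foreach and the saveAs* operators end up calling registerOutputStream on the StreamingContext, as the DStream and PairDStreamFunctions changes in this commit show.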
+private[streaming] class GlommedDStream[T: ClassManifest](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { @@ -676,6 +762,7 @@ class GlommedDStream[T: ClassManifest](parent: DStream[T]) * TODO */ +private[streaming] class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( parent: DStream[(K,V)], createCombiner: V => C, @@ -702,6 +789,7 @@ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( * TODO */ +private[streaming] class MapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( parent: DStream[(K, V)], mapValueFunc: V => U @@ -720,7 +808,7 @@ class MapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( /** * TODO */ - +private[streaming] class FlatMapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( parent: DStream[(K, V)], flatMapValueFunc: V => TraversableOnce[U] @@ -779,38 +867,8 @@ class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) * TODO */ -class PerElementForEachDStream[T: ClassManifest] ( - parent: DStream[T], - foreachFunc: T => Unit - ) extends DStream[Unit](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[Unit]] = None - - override def generateJob(time: Time): Option[Job] = { - parent.getOrCompute(time) match { - case Some(rdd) => - val jobFunc = () => { - val sparkJobFunc = { - (iterator: Iterator[T]) => iterator.foreach(foreachFunc) - } - ssc.sc.runJob(rdd, sparkJobFunc) - } - Some(new Job(time, jobFunc)) - case None => None - } - } -} - - -/** - * TODO - */ - -class PerRDDForEachDStream[T: ClassManifest] ( +private[streaming] +class ForEachDStream[T: ClassManifest] ( parent: DStream[T], foreachFunc: (RDD[T], Time) => Unit ) extends DStream[Unit](parent.ssc) { @@ -838,6 +896,7 @@ class PerRDDForEachDStream[T: ClassManifest] ( * TODO */ +private[streaming] class TransformedDStream[T: ClassManifest, U: ClassManifest] ( parent: DStream[T], transformFunc: (RDD[T], Time) => RDD[U] diff --git a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala index 2959ce4540..5ac7e5b08e 100644 --- a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala @@ -79,7 +79,7 @@ class SparkFlumeEvent() extends Externalizable { } } -object SparkFlumeEvent { +private[streaming] object SparkFlumeEvent { def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = { val event = new SparkFlumeEvent event.event = in diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala index ffb7725ac9..fa0b7ce19d 100644 --- a/streaming/src/main/scala/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/spark/streaming/Interval.scala @@ -1,5 +1,6 @@ package spark.streaming +private[streaming] case class Interval(beginTime: Time, endTime: Time) { def this(beginMs: Long, endMs: Long) = this(Time(beginMs), new Time(endMs)) diff --git a/streaming/src/main/scala/spark/streaming/Job.scala b/streaming/src/main/scala/spark/streaming/Job.scala index 0bcb6fd8dc..67bd8388bc 100644 --- a/streaming/src/main/scala/spark/streaming/Job.scala +++ b/streaming/src/main/scala/spark/streaming/Job.scala @@ -2,6 +2,7 @@ package spark.streaming import java.util.concurrent.atomic.AtomicLong +private[streaming] class Job(val time: Time, func: () => _) { val id = 
Job.getNewId() def run(): Long = { @@ -14,6 +15,7 @@ class Job(val time: Time, func: () => _) { override def toString = "streaming job " + id + " @ " + time } +private[streaming] object Job { val id = new AtomicLong(0) diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 9bf9251519..fda7264a27 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -5,6 +5,7 @@ import spark.SparkEnv import java.util.concurrent.Executors +private[streaming] class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { class JobHandler(ssc: StreamingContext, job: Job) extends Runnable { diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala index 4e4e9fc942..4bf13dd50c 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -40,10 +40,10 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming } -sealed trait NetworkReceiverMessage -case class StopReceiver(msg: String) extends NetworkReceiverMessage -case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage -case class ReportError(msg: String) extends NetworkReceiverMessage +private[streaming] sealed trait NetworkReceiverMessage +private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage +private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage +private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging { diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index b421f795ee..658498dfc1 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -11,10 +11,10 @@ import akka.pattern.ask import akka.util.duration._ import akka.dispatch._ -trait NetworkInputTrackerMessage -case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage -case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage -case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage +private[streaming] sealed trait NetworkInputTrackerMessage +private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage +private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage +private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage class NetworkInputTracker( diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 720e63bba0..f9fef14196 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -281,7 +281,7 @@ extends Serializable { val file = rddToFileName(prefix, suffix, time) 
rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) } - self.foreachRDD(saveFunc) + self.foreach(saveFunc) } def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]]( @@ -303,7 +303,7 @@ extends Serializable { val file = rddToFileName(prefix, suffix, time) rdd.saveAsNewAPIHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) } - self.foreachRDD(saveFunc) + self.foreach(saveFunc) } private def getKeyClass() = implicitly[ClassManifest[K]].erasure diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 014021be61..fd1fa77a24 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -7,11 +7,8 @@ import spark.Logging import scala.collection.mutable.HashMap -sealed trait SchedulerMessage -case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage - -class Scheduler(ssc: StreamingContext) -extends Logging { +private[streaming] +class Scheduler(ssc: StreamingContext) extends Logging { initLogging() diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index ce47bcb2da..998fea849f 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -48,7 +48,7 @@ class StreamingContext private ( this(StreamingContext.createNewSparkContext(master, frameworkName), null, batchDuration) /** - * Recreates the StreamingContext from a checkpoint file. + * Re-creates a StreamingContext from a checkpoint file. * @param path Path either to the directory that was specified as the checkpoint directory, or * to the checkpoint file 'graph' or 'graph.bk'. 
*/ @@ -61,7 +61,7 @@ class StreamingContext private ( "both SparkContext and checkpoint as null") } - val isCheckpointPresent = (cp_ != null) + protected[streaming] val isCheckpointPresent = (cp_ != null) val sc: SparkContext = { if (isCheckpointPresent) { @@ -71,9 +71,9 @@ class StreamingContext private ( } } - val env = SparkEnv.get + protected[streaming] val env = SparkEnv.get - val graph: DStreamGraph = { + protected[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { cp_.graph.setContext(this) cp_.graph.restoreCheckpointData() @@ -86,10 +86,10 @@ class StreamingContext private ( } } - private[streaming] val nextNetworkInputStreamId = new AtomicInteger(0) - private[streaming] var networkInputTracker: NetworkInputTracker = null + protected[streaming] val nextNetworkInputStreamId = new AtomicInteger(0) + protected[streaming] var networkInputTracker: NetworkInputTracker = null - private[streaming] var checkpointDir: String = { + protected[streaming] var checkpointDir: String = { if (isCheckpointPresent) { sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(cp_.checkpointDir), true) cp_.checkpointDir @@ -98,9 +98,9 @@ class StreamingContext private ( } } - private[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null - private[streaming] var receiverJobThread: Thread = null - private[streaming] var scheduler: Scheduler = null + protected[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null + protected[streaming] var receiverJobThread: Thread = null + protected[streaming] var scheduler: Scheduler = null def remember(duration: Time) { graph.remember(duration) @@ -117,11 +117,11 @@ class StreamingContext private ( } } - private[streaming] def getInitialCheckpoint(): Checkpoint = { + protected[streaming] def getInitialCheckpoint(): Checkpoint = { if (isCheckpointPresent) cp_ else null } - private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() + protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() /** * Create an input stream that pulls messages form a Kafka Broker. @@ -188,7 +188,7 @@ class StreamingContext private ( } /** - * This function creates a input stream that monitors a Hadoop-compatible filesystem + * Creates a input stream that monitors a Hadoop-compatible filesystem * for new files and executes the necessary processing on them. */ def fileStream[ @@ -206,7 +206,7 @@ class StreamingContext private ( } /** - * This function create a input stream from an queue of RDDs. In each batch, + * Creates a input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue */ def queueStream[T: ClassManifest]( @@ -231,22 +231,21 @@ class StreamingContext private ( } /** - * This function registers a InputDStream as an input stream that will be - * started (InputDStream.start() called) to get the input data streams. + * Registers an input stream that will be started (InputDStream.start() called) to get the + * input data. */ def registerInputStream(inputStream: InputDStream[_]) { graph.addInputStream(inputStream) } /** - * This function registers a DStream as an output stream that will be - * computed every interval. 
+ * Registers an output stream that will be computed every interval */ def registerOutputStream(outputStream: DStream[_]) { graph.addOutputStream(outputStream) } - def validate() { + protected def validate() { assert(graph != null, "Graph is null") graph.validate() @@ -304,7 +303,7 @@ class StreamingContext private ( object StreamingContext { - def createNewSparkContext(master: String, frameworkName: String): SparkContext = { + protected[streaming] def createNewSparkContext(master: String, frameworkName: String): SparkContext = { // Set the default cleaner delay to an hour if not already set. // This should be sufficient for even 1 second interval. @@ -318,7 +317,7 @@ object StreamingContext { new PairDStreamFunctions[K, V](stream) } - def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { + protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { if (prefix == null) { time.millis.toString } else if (suffix == null || suffix.length ==0) { @@ -328,7 +327,7 @@ object StreamingContext { } } - def getSparkCheckpointDir(sscCheckpointDir: String): String = { + protected[streaming] def getSparkCheckpointDir(sscCheckpointDir: String): String = { new Path(sscCheckpointDir, UUID.randomUUID.toString).toString } } diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala index 6cb2b4c042..7c4ee3b34c 100644 --- a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala @@ -26,7 +26,7 @@ object GrepRaw { val rawStreams = (1 to numStreams).map(_ => ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray val union = new UnionDStream(rawStreams) - union.filter(_.contains("Alice")).count().foreachRDD(r => + union.filter(_.contains("Alice")).count().foreach(r => println("Grep count: " + r.collect().mkString)) ssc.start() } diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index fe4c2bf155..182dfd8a52 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -38,7 +38,7 @@ object TopKWordCountRaw { val counts = union.mapPartitions(splitAndCountPartitions) val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { + partialTopKWindowedCounts.foreach(rdd => { val collectedCounts = rdd.collect println("Collected " + collectedCounts.size + " words from partial top words") println("Top " + k + " words are " + topK(collectedCounts.toIterator, k).mkString(",")) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index a29c81d437..9bcd30f4d7 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -36,7 +36,7 @@ object WordCountRaw { val union = new UnionDStream(lines.toArray) val counts = union.mapPartitions(splitAndCountPartitions) val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) - windowedCounts.foreachRDD(r => println("# unique words = " + 
r.count())) + windowedCounts.foreach(r => println("# unique words = " + r.count())) ssc.start() } diff --git a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala b/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala index 68be6b7893..a191321d91 100644 --- a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala +++ b/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala @@ -72,7 +72,7 @@ object PageViewStream { case "popularUsersSeen" => // Look for users in our existing dataset and print it out if we have a match pageViews.map(view => (view.userID, 1)) - .foreachRDD((rdd, time) => rdd.join(userList) + .foreach((rdd, time) => rdd.join(userList) .map(_._2._2) .take(10) .foreach(u => println("Saw user %s at time %s".format(u, time)))) diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index 8cc2f8ccfc..a44f738957 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -35,7 +35,7 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[ * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. */ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]]) - extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { + extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() output += collected }) { -- cgit v1.2.3 From 7e0271b4387eaf27cd96f3057ce2465b1271a480 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 30 Dec 2012 15:19:55 -0800 Subject: Refactored a whole lot to push all DStreams into the spark.streaming.dstream package. 
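The commit below pushes every concrete DStream implementation into a new spark.streaming.dstream package. As a hedged illustration of what that means for user and example code, here is a sketch modeled on the GrepRaw example touched by this commit: the logic is unchanged, but directly-instantiated classes such as UnionDStream are now imported from the new package. The master URL, host, port and the stream count are placeholders.

    import spark.storage.StorageLevel
    import spark.streaming.{Seconds, StreamingContext}
    import spark.streaming.StreamingContext._
    // After this refactor, concrete DStream classes live under spark.streaming.dstream.
    import spark.streaming.dstream.UnionDStream

    object GrepRawAfterRefactor {
      def main(args: Array[String]) {
        val ssc = new StreamingContext("local[4]", "GrepRawAfterRefactor", Seconds(1))
        // Two raw streams, unioned by constructing UnionDStream directly, as GrepRaw does.
        val rawStreams = (1 to 2).map(_ =>
          ssc.rawNetworkStream[String]("localhost", 9999, StorageLevel.MEMORY_ONLY_SER_2)).toArray
        val union = new UnionDStream(rawStreams)
        union.filter(_.contains("Alice")).count().foreach(r =>
          println("Grep count: " + r.collect().mkString))
        ssc.start()
      }
    }

Classes marked private[streaming] in this patch series (MappedDStream, FilteredDStream, ForEachDStream and friends) are not meant to be reached this way; only the public ones such as UnionDStream and the input streams are constructed directly by the examples.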
--- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 1 + .../scala/spark/streaming/CoGroupedDStream.scala | 38 --- .../spark/streaming/ConstantInputDStream.scala | 18 -- .../src/main/scala/spark/streaming/DStream.scala | 276 +-------------------- .../main/scala/spark/streaming/DStreamGraph.scala | 1 + .../main/scala/spark/streaming/DataHandler.scala | 83 ------- .../scala/spark/streaming/FileInputDStream.scala | 109 -------- .../scala/spark/streaming/FlumeInputDStream.scala | 130 ---------- .../spark/streaming/NetworkInputDStream.scala | 156 ------------ .../spark/streaming/NetworkInputTracker.scala | 2 + .../spark/streaming/PairDStreamFunctions.scala | 7 +- .../scala/spark/streaming/QueueInputDStream.scala | 40 --- .../scala/spark/streaming/RawInputDStream.scala | 85 ------- .../spark/streaming/ReducedWindowedDStream.scala | 149 ----------- .../src/main/scala/spark/streaming/Scheduler.scala | 3 - .../scala/spark/streaming/SocketInputDStream.scala | 107 -------- .../main/scala/spark/streaming/StateDStream.scala | 84 ------- .../scala/spark/streaming/StreamingContext.scala | 13 +- .../src/main/scala/spark/streaming/Time.scala | 11 +- .../scala/spark/streaming/WindowedDStream.scala | 39 --- .../spark/streaming/dstream/CoGroupedDStream.scala | 39 +++ .../streaming/dstream/ConstantInputDStream.scala | 19 ++ .../spark/streaming/dstream/DataHandler.scala | 83 +++++++ .../spark/streaming/dstream/FileInputDStream.scala | 110 ++++++++ .../spark/streaming/dstream/FilteredDStream.scala | 21 ++ .../streaming/dstream/FlatMapValuedDStream.scala | 20 ++ .../streaming/dstream/FlatMappedDStream.scala | 20 ++ .../streaming/dstream/FlumeInputDStream.scala | 135 ++++++++++ .../spark/streaming/dstream/ForEachDStream.scala | 28 +++ .../spark/streaming/dstream/GlommedDStream.scala | 17 ++ .../spark/streaming/dstream/InputDStream.scala | 19 ++ .../streaming/dstream/KafkaInputDStream.scala | 197 +++++++++++++++ .../streaming/dstream/MapPartitionedDStream.scala | 21 ++ .../spark/streaming/dstream/MapValuedDStream.scala | 21 ++ .../spark/streaming/dstream/MappedDStream.scala | 20 ++ .../streaming/dstream/NetworkInputDStream.scala | 157 ++++++++++++ .../streaming/dstream/QueueInputDStream.scala | 41 +++ .../spark/streaming/dstream/RawInputDStream.scala | 88 +++++++ .../streaming/dstream/ReducedWindowedDStream.scala | 148 +++++++++++ .../spark/streaming/dstream/ShuffledDStream.scala | 27 ++ .../streaming/dstream/SocketInputDStream.scala | 103 ++++++++ .../spark/streaming/dstream/StateDStream.scala | 83 +++++++ .../streaming/dstream/TransformedDStream.scala | 19 ++ .../spark/streaming/dstream/UnionDStream.scala | 39 +++ .../spark/streaming/dstream/WindowedDStream.scala | 40 +++ .../scala/spark/streaming/examples/GrepRaw.scala | 2 +- .../streaming/examples/TopKWordCountRaw.scala | 2 +- .../spark/streaming/examples/WordCountRaw.scala | 2 +- .../spark/streaming/input/KafkaInputDStream.scala | 193 -------------- .../scala/spark/streaming/CheckpointSuite.scala | 2 +- .../test/scala/spark/streaming/FailureSuite.scala | 2 +- .../scala/spark/streaming/InputStreamsSuite.scala | 1 + .../test/scala/spark/streaming/TestSuiteBase.scala | 48 +++- .../spark/streaming/WindowOperationsSuite.scala | 12 +- 54 files changed, 1600 insertions(+), 1531 deletions(-) delete mode 100644 streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/DataHandler.scala delete mode 100644 
streaming/src/main/scala/spark/streaming/FileInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/QueueInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/RawInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/SocketInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/StateDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/WindowedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/DataHandler.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index f40b56be64..1b219473e0 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -1,6 +1,7 @@ package spark.rdd import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext} +import spark.SparkContext._ private[spark] class 
ShuffledRDDSplit(val idx: Int) extends Split { override val index = idx diff --git a/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala deleted file mode 100644 index 61d088eddb..0000000000 --- a/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala +++ /dev/null @@ -1,38 +0,0 @@ -package spark.streaming - -import spark.{RDD, Partitioner} -import spark.rdd.CoGroupedRDD - -class CoGroupedDStream[K : ClassManifest]( - parents: Seq[DStream[(_, _)]], - partitioner: Partitioner - ) extends DStream[(K, Seq[Seq[_]])](parents.head.ssc) { - - if (parents.length == 0) { - throw new IllegalArgumentException("Empty array of parents") - } - - if (parents.map(_.ssc).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different StreamingContexts") - } - - if (parents.map(_.slideTime).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different slide times") - } - - override def dependencies = parents.toList - - override def slideTime = parents.head.slideTime - - override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = { - val part = partitioner - val rdds = parents.flatMap(_.getOrCompute(validTime)) - if (rdds.size > 0) { - val q = new CoGroupedRDD[K](rdds, part) - Some(q) - } else { - None - } - } - -} diff --git a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala deleted file mode 100644 index 80150708fd..0000000000 --- a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala +++ /dev/null @@ -1,18 +0,0 @@ -package spark.streaming - -import spark.RDD - -/** - * An input stream that always returns the same RDD on each timestep. Useful for testing. - */ -class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T]) - extends InputDStream[T](ssc_) { - - override def start() {} - - override def stop() {} - - override def compute(validTime: Time): Option[RDD[T]] = { - Some(rdd) - } -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 3834b57ed3..292ad3b9f9 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -1,17 +1,15 @@ package spark.streaming +import spark.streaming.dstream._ import StreamingContext._ import Time._ -import spark._ -import spark.SparkContext._ -import spark.rdd._ +import spark.{RDD, Logging} import spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import java.util.concurrent.ArrayBlockingQueue import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import org.apache.hadoop.fs.Path @@ -197,7 +195,7 @@ abstract class DStream[T: ClassManifest] ( "than " + rememberDuration.milliseconds + " milliseconds. But the Spark's metadata cleanup" + "delay is set to " + (metadataCleanerDelay / 60.0) + " minutes, which is not sufficient. Please set " + "the Java property 'spark.cleaner.delay' to more than " + - math.ceil(rememberDuration.millis.toDouble / 60000.0).toInt + " minutes." + math.ceil(rememberDuration.milliseconds.toDouble / 60000.0).toInt + " minutes." 
) dependencies.foreach(_.validate()) @@ -642,271 +640,3 @@ abstract class DStream[T: ClassManifest] ( private[streaming] case class DStreamCheckpointData(rdds: HashMap[Time, Any]) -abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) - extends DStream[T](ssc_) { - - override def dependencies = List() - - override def slideTime = { - if (ssc == null) throw new Exception("ssc is null") - if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null") - ssc.graph.batchDuration - } - - def start() - - def stop() -} - - -/** - * TODO - */ - -private[streaming] -class MappedDStream[T: ClassManifest, U: ClassManifest] ( - parent: DStream[T], - mapFunc: T => U - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.map[U](mapFunc)) - } -} - - -/** - * TODO - */ - -private[streaming] -class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( - parent: DStream[T], - flatMapFunc: T => Traversable[U] - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) - } -} - - -/** - * TODO - */ - -private[streaming] -class FilteredDStream[T: ClassManifest]( - parent: DStream[T], - filterFunc: T => Boolean - ) extends DStream[T](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[T]] = { - parent.getOrCompute(validTime).map(_.filter(filterFunc)) - } -} - - -/** - * TODO - */ - -private[streaming] -class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( - parent: DStream[T], - mapPartFunc: Iterator[T] => Iterator[U], - preservePartitioning: Boolean - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc, preservePartitioning)) - } -} - - -/** - * TODO - */ - -private[streaming] -class GlommedDStream[T: ClassManifest](parent: DStream[T]) - extends DStream[Array[T]](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[Array[T]]] = { - parent.getOrCompute(validTime).map(_.glom()) - } -} - - -/** - * TODO - */ - -private[streaming] -class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( - parent: DStream[(K,V)], - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiner: (C, C) => C, - partitioner: Partitioner - ) extends DStream [(K,C)] (parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[(K,C)]] = { - parent.getOrCompute(validTime) match { - case Some(rdd) => - Some(rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner)) - case None => None - } - } -} - - -/** - * TODO - */ - -private[streaming] -class MapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( - parent: DStream[(K, V)], - mapValueFunc: V => U - ) extends DStream[(K, U)](parent.ssc) { - - override def dependencies = 
List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[(K, U)]] = { - parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc)) - } -} - - -/** - * TODO - */ -private[streaming] -class FlatMapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( - parent: DStream[(K, V)], - flatMapValueFunc: V => TraversableOnce[U] - ) extends DStream[(K, U)](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[(K, U)]] = { - parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc)) - } -} - - - -/** - * TODO - */ - -class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) - extends DStream[T](parents.head.ssc) { - - if (parents.length == 0) { - throw new IllegalArgumentException("Empty array of parents") - } - - if (parents.map(_.ssc).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different StreamingContexts") - } - - if (parents.map(_.slideTime).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different slide times") - } - - override def dependencies = parents.toList - - override def slideTime: Time = parents.head.slideTime - - override def compute(validTime: Time): Option[RDD[T]] = { - val rdds = new ArrayBuffer[RDD[T]]() - parents.map(_.getOrCompute(validTime)).foreach(_ match { - case Some(rdd) => rdds += rdd - case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) - }) - if (rdds.size > 0) { - Some(new UnionRDD(ssc.sc, rdds)) - } else { - None - } - } -} - - -/** - * TODO - */ - -private[streaming] -class ForEachDStream[T: ClassManifest] ( - parent: DStream[T], - foreachFunc: (RDD[T], Time) => Unit - ) extends DStream[Unit](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[Unit]] = None - - override def generateJob(time: Time): Option[Job] = { - parent.getOrCompute(time) match { - case Some(rdd) => - val jobFunc = () => { - foreachFunc(rdd, time) - } - Some(new Job(time, jobFunc)) - case None => None - } - } -} - - -/** - * TODO - */ - -private[streaming] -class TransformedDStream[T: ClassManifest, U: ClassManifest] ( - parent: DStream[T], - transformFunc: (RDD[T], Time) => RDD[U] - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(transformFunc(_, validTime)) - } -} diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index d0a9ade61d..c72429370e 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -1,5 +1,6 @@ package spark.streaming +import dstream.InputDStream import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import collection.mutable.ArrayBuffer import spark.Logging diff --git a/streaming/src/main/scala/spark/streaming/DataHandler.scala b/streaming/src/main/scala/spark/streaming/DataHandler.scala deleted file mode 100644 index 05f307a8d1..0000000000 --- a/streaming/src/main/scala/spark/streaming/DataHandler.scala +++ /dev/null @@ -1,83 +0,0 @@ -package spark.streaming - -import 
java.util.concurrent.ArrayBlockingQueue -import scala.collection.mutable.ArrayBuffer -import spark.Logging -import spark.streaming.util.{RecurringTimer, SystemClock} -import spark.storage.StorageLevel - - -/** - * This is a helper object that manages the data received from the socket. It divides - * the object received into small batches of 100s of milliseconds, pushes them as - * blocks into the block manager and reports the block IDs to the network input - * tracker. It starts two threads, one to periodically start a new batch and prepare - * the previous batch of as a block, the other to push the blocks into the block - * manager. - */ - class DataHandler[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel) - extends Serializable with Logging { - - case class Block(id: String, iterator: Iterator[T], metadata: Any = null) - - val clock = new SystemClock() - val blockInterval = 200L - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) - val blockStorageLevel = storageLevel - val blocksForPushing = new ArrayBlockingQueue[Block](1000) - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - var currentBuffer = new ArrayBuffer[T] - - def createBlock(blockId: String, iterator: Iterator[T]) : Block = { - new Block(blockId, iterator) - } - - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Data handler started") - } - - def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - logInfo("Data handler stopped") - } - - def += (obj: T) { - currentBuffer += obj - } - - def updateCurrentBuffer(time: Long) { - try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[T] - if (newBlockBuffer.size > 0) { - val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval) - val newBlock = createBlock(blockId, newBlockBuffer.toIterator) - blocksForPushing.add(newBlock) - } - } catch { - case ie: InterruptedException => - logInfo("Block interval timer thread interrupted") - case e: Exception => - receiver.stop() - } - } - - def keepPushingBlocks() { - logInfo("Block pushing thread started") - try { - while(true) { - val block = blocksForPushing.take() - receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel) - } - } catch { - case ie: InterruptedException => - logInfo("Block pushing thread interrupted") - case e: Exception => - receiver.stop() - } - } - } \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala deleted file mode 100644 index 88856364d2..0000000000 --- a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala +++ /dev/null @@ -1,109 +0,0 @@ -package spark.streaming - -import spark.RDD -import spark.rdd.UnionRDD - -import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} - -import scala.collection.mutable.HashSet - - -class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( - @transient ssc_ : StreamingContext, - directory: String, - filter: PathFilter = FileInputDStream.defaultPathFilter, - newFilesOnly: Boolean = true) - extends InputDStream[(K, V)](ssc_) { - - @transient private var path_ : Path = null - @transient private var fs_ : FileSystem = null - - var lastModTime = 0L - val lastModTimeFiles = new HashSet[String]() - - def path(): 
Path = { - if (path_ == null) path_ = new Path(directory) - path_ - } - - def fs(): FileSystem = { - if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) - fs_ - } - - override def start() { - if (newFilesOnly) { - lastModTime = System.currentTimeMillis() - } else { - lastModTime = 0 - } - } - - override def stop() { } - - /** - * Finds the files that were modified since the last time this method was called and makes - * a union RDD out of them. Note that this maintains the list of files that were processed - * in the latest modification time in the previous call to this method. This is because the - * modification time returned by the FileStatus API seems to return times only at the - * granularity of seconds. Hence, new files may have the same modification time as the - * latest modification time in the previous call to this method and the list of files - * maintained is used to filter the one that have been processed. - */ - override def compute(validTime: Time): Option[RDD[(K, V)]] = { - // Create the filter for selecting new files - val newFilter = new PathFilter() { - var latestModTime = 0L - val latestModTimeFiles = new HashSet[String]() - - def accept(path: Path): Boolean = { - if (!filter.accept(path)) { - return false - } else { - val modTime = fs.getFileStatus(path).getModificationTime() - if (modTime < lastModTime){ - return false - } else if (modTime == lastModTime && lastModTimeFiles.contains(path.toString)) { - return false - } - if (modTime > latestModTime) { - latestModTime = modTime - latestModTimeFiles.clear() - } - latestModTimeFiles += path.toString - return true - } - } - } - - val newFiles = fs.listStatus(path, newFilter) - logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) - if (newFiles.length > 0) { - // Update the modification time and the files processed for that modification time - if (lastModTime != newFilter.latestModTime) { - lastModTime = newFilter.latestModTime - lastModTimeFiles.clear() - } - lastModTimeFiles ++= newFilter.latestModTimeFiles - } - val newRDD = new UnionRDD(ssc.sc, newFiles.map( - file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) - Some(newRDD) - } -} - -object FileInputDStream { - val defaultPathFilter = new PathFilter with Serializable { - def accept(path: Path): Boolean = { - val file = path.getName() - if (file.startsWith(".") || file.endsWith("_tmp")) { - return false - } else { - return true - } - } - } -} - diff --git a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala deleted file mode 100644 index 5ac7e5b08e..0000000000 --- a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala +++ /dev/null @@ -1,130 +0,0 @@ -package spark.streaming - -import java.io.{ObjectInput, ObjectOutput, Externalizable} -import spark.storage.StorageLevel -import org.apache.flume.source.avro.AvroSourceProtocol -import org.apache.flume.source.avro.AvroFlumeEvent -import org.apache.flume.source.avro.Status -import org.apache.avro.ipc.specific.SpecificResponder -import org.apache.avro.ipc.NettyServer -import java.net.InetSocketAddress -import collection.JavaConversions._ -import spark.Utils -import java.nio.ByteBuffer - -class FlumeInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - storageLevel: StorageLevel -) extends NetworkInputDStream[SparkFlumeEvent](ssc_) { - - override def createReceiver(): NetworkReceiver[SparkFlumeEvent] = { - new FlumeReceiver(id, host, port, 
storageLevel) - } -} - -/** - * A wrapper class for AvroFlumeEvent's with a custom serialization format. - * - * This is necessary because AvroFlumeEvent uses inner data structures - * which are not serializable. - */ -class SparkFlumeEvent() extends Externalizable { - var event : AvroFlumeEvent = new AvroFlumeEvent() - - /* De-serialize from bytes. */ - def readExternal(in: ObjectInput) { - val bodyLength = in.readInt() - val bodyBuff = new Array[Byte](bodyLength) - in.read(bodyBuff) - - val numHeaders = in.readInt() - val headers = new java.util.HashMap[CharSequence, CharSequence] - - for (i <- 0 until numHeaders) { - val keyLength = in.readInt() - val keyBuff = new Array[Byte](keyLength) - in.read(keyBuff) - val key : String = Utils.deserialize(keyBuff) - - val valLength = in.readInt() - val valBuff = new Array[Byte](valLength) - in.read(valBuff) - val value : String = Utils.deserialize(valBuff) - - headers.put(key, value) - } - - event.setBody(ByteBuffer.wrap(bodyBuff)) - event.setHeaders(headers) - } - - /* Serialize to bytes. */ - def writeExternal(out: ObjectOutput) { - val body = event.getBody.array() - out.writeInt(body.length) - out.write(body) - - val numHeaders = event.getHeaders.size() - out.writeInt(numHeaders) - for ((k, v) <- event.getHeaders) { - val keyBuff = Utils.serialize(k.toString) - out.writeInt(keyBuff.length) - out.write(keyBuff) - val valBuff = Utils.serialize(v.toString) - out.writeInt(valBuff.length) - out.write(valBuff) - } - } -} - -private[streaming] object SparkFlumeEvent { - def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = { - val event = new SparkFlumeEvent - event.event = in - event - } -} - -/** A simple server that implements Flume's Avro protocol. */ -class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { - override def append(event : AvroFlumeEvent) : Status = { - receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event) - Status.OK - } - - override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { - events.foreach (event => - receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event)) - Status.OK - } -} - -/** A NetworkReceiver which listens for events using the - * Flume Avro interface.*/ -class FlumeReceiver( - streamId: Int, - host: String, - port: Int, - storageLevel: StorageLevel - ) extends NetworkReceiver[SparkFlumeEvent](streamId) { - - lazy val dataHandler = new DataHandler(this, storageLevel) - - protected override def onStart() { - val responder = new SpecificResponder( - classOf[AvroSourceProtocol], new FlumeEventServer(this)); - val server = new NettyServer(responder, new InetSocketAddress(host, port)); - dataHandler.start() - server.start() - logInfo("Flume receiver started") - } - - protected override def onStop() { - dataHandler.stop() - logInfo("Flume receiver stopped") - } - - override def getLocationPreference = Some(host) -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala deleted file mode 100644 index 4bf13dd50c..0000000000 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ /dev/null @@ -1,156 +0,0 @@ -package spark.streaming - -import scala.collection.mutable.ArrayBuffer - -import spark.{Logging, SparkEnv, RDD} -import spark.rdd.BlockRDD -import spark.streaming.util.{RecurringTimer, SystemClock} -import spark.storage.StorageLevel - -import java.nio.ByteBuffer - -import akka.actor.{Props, Actor} -import 
akka.pattern.ask -import akka.dispatch.Await -import akka.util.duration._ - -abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext) - extends InputDStream[T](ssc_) { - - // This is an unique identifier that is used to match the network receiver with the - // corresponding network input stream. - val id = ssc.getNewNetworkStreamId() - - /** - * This method creates the receiver object that will be sent to the workers - * to receive data. This method needs to defined by any specific implementation - * of a NetworkInputDStream. - */ - def createReceiver(): NetworkReceiver[T] - - // Nothing to start or stop as both taken care of by the NetworkInputTracker. - def start() {} - - def stop() {} - - override def compute(validTime: Time): Option[RDD[T]] = { - val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) - Some(new BlockRDD[T](ssc.sc, blockIds)) - } -} - - -private[streaming] sealed trait NetworkReceiverMessage -private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage -private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage -private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage - -abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging { - - initLogging() - - lazy protected val env = SparkEnv.get - - lazy protected val actor = env.actorSystem.actorOf( - Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId) - - lazy protected val receivingThread = Thread.currentThread() - - /** This method will be called to start receiving data. */ - protected def onStart() - - /** This method will be called to stop receiving data. */ - protected def onStop() - - /** This method conveys a placement preference (hostname) for this receiver. */ - def getLocationPreference() : Option[String] = None - - /** - * This method starts the receiver. First is accesses all the lazy members to - * materialize them. Then it calls the user-defined onStart() method to start - * other threads, etc required to receiver the data. - */ - def start() { - try { - // Access the lazy vals to materialize them - env - actor - receivingThread - - // Call user-defined onStart() - onStart() - } catch { - case ie: InterruptedException => - logInfo("Receiving thread interrupted") - //println("Receiving thread interrupted") - case e: Exception => - stopOnError(e) - } - } - - /** - * This method stops the receiver. First it interrupts the main receiving thread, - * that is, the thread that called receiver.start(). Then it calls the user-defined - * onStop() method to stop other threads and/or do cleanup. - */ - def stop() { - receivingThread.interrupt() - onStop() - //TODO: terminate the actor - } - - /** - * This method stops the receiver and reports to exception to the tracker. - * This should be called whenever an exception has happened on any thread - * of the receiver. - */ - protected def stopOnError(e: Exception) { - logError("Error receiving data", e) - stop() - actor ! ReportError(e.toString) - } - - - /** - * This method pushes a block (as iterator of values) into the block manager. - */ - def pushBlock(blockId: String, iterator: Iterator[T], metadata: Any, level: StorageLevel) { - val buffer = new ArrayBuffer[T] ++ iterator - env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level) - - actor ! ReportBlock(blockId, metadata) - } - - /** - * This method pushes a block (as bytes) into the block manager. 
- */ - def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { - env.blockManager.putBytes(blockId, bytes, level) - actor ! ReportBlock(blockId, metadata) - } - - /** A helper actor that communicates with the NetworkInputTracker */ - private class NetworkReceiverActor extends Actor { - logInfo("Attempting to register with tracker") - val ip = System.getProperty("spark.master.host", "localhost") - val port = System.getProperty("spark.master.port", "7077").toInt - val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port) - val tracker = env.actorSystem.actorFor(url) - val timeout = 5.seconds - - override def preStart() { - val future = tracker.ask(RegisterReceiver(streamId, self))(timeout) - Await.result(future, timeout) - } - - override def receive() = { - case ReportBlock(blockId, metadata) => - tracker ! AddBlocks(streamId, Array(blockId), metadata) - case ReportError(msg) => - tracker ! DeregisterReceiver(streamId, msg) - case StopReceiver(msg) => - stop() - tracker ! DeregisterReceiver(streamId, msg) - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index 658498dfc1..a6ab44271f 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -1,5 +1,7 @@ package spark.streaming +import spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} +import spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} import spark.Logging import spark.SparkEnv diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index f9fef14196..b0a208e67f 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -1,6 +1,9 @@ package spark.streaming import spark.streaming.StreamingContext._ +import spark.streaming.dstream.{ReducedWindowedDStream, StateDStream} +import spark.streaming.dstream.{CoGroupedDStream, ShuffledDStream} +import spark.streaming.dstream.{MapValuedDStream, FlatMapValuedDStream} import spark.{Manifests, RDD, Partitioner, HashPartitioner} import spark.SparkContext._ @@ -218,13 +221,13 @@ extends Serializable { def mapValues[U: ClassManifest](mapValuesFunc: V => U): DStream[(K, U)] = { - new MapValuesDStream[K, V, U](self, mapValuesFunc) + new MapValuedDStream[K, V, U](self, mapValuesFunc) } def flatMapValues[U: ClassManifest]( flatMapValuesFunc: V => TraversableOnce[U] ): DStream[(K, U)] = { - new FlatMapValuesDStream[K, V, U](self, flatMapValuesFunc) + new FlatMapValuedDStream[K, V, U](self, flatMapValuesFunc) } def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = { diff --git a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala deleted file mode 100644 index bb86e51932..0000000000 --- a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala +++ /dev/null @@ -1,40 +0,0 @@ -package spark.streaming - -import spark.RDD -import spark.rdd.UnionRDD - -import scala.collection.mutable.Queue -import scala.collection.mutable.ArrayBuffer - -class QueueInputDStream[T: ClassManifest]( - @transient ssc: StreamingContext, - val queue: Queue[RDD[T]], - oneAtATime: Boolean, - defaultRDD: RDD[T] - ) extends InputDStream[T](ssc) { - - 
override def start() { } - - override def stop() { } - - override def compute(validTime: Time): Option[RDD[T]] = { - val buffer = new ArrayBuffer[RDD[T]]() - if (oneAtATime && queue.size > 0) { - buffer += queue.dequeue() - } else { - buffer ++= queue - } - if (buffer.size > 0) { - if (oneAtATime) { - Some(buffer.first) - } else { - Some(new UnionRDD(ssc.sc, buffer.toSeq)) - } - } else if (defaultRDD != null) { - Some(defaultRDD) - } else { - None - } - } - -} diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala deleted file mode 100644 index 6acaa9aab1..0000000000 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ /dev/null @@ -1,85 +0,0 @@ -package spark.streaming - -import java.net.InetSocketAddress -import java.nio.ByteBuffer -import java.nio.channels.{ReadableByteChannel, SocketChannel} -import java.io.EOFException -import java.util.concurrent.ArrayBlockingQueue -import spark._ -import spark.storage.StorageLevel - -/** - * An input stream that reads blocks of serialized objects from a given network address. - * The blocks will be inserted directly into the block store. This is the fastest way to get - * data into Spark Streaming, though it requires the sender to batch data and serialize it - * in the format that the system is configured with. - */ -class RawInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - storageLevel: StorageLevel - ) extends NetworkInputDStream[T](ssc_ ) with Logging { - - def createReceiver(): NetworkReceiver[T] = { - new RawNetworkReceiver(id, host, port, storageLevel).asInstanceOf[NetworkReceiver[T]] - } -} - -class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel) - extends NetworkReceiver[Any](streamId) { - - var blockPushingThread: Thread = null - - override def getLocationPreference = None - - def onStart() { - // Open a socket to the target address and keep reading from it - logInfo("Connecting to " + host + ":" + port) - val channel = SocketChannel.open() - channel.configureBlocking(true) - channel.connect(new InetSocketAddress(host, port)) - logInfo("Connected to " + host + ":" + port) - - val queue = new ArrayBlockingQueue[ByteBuffer](2) - - blockPushingThread = new DaemonThread { - override def run() { - var nextBlockNumber = 0 - while (true) { - val buffer = queue.take() - val blockId = "input-" + streamId + "-" + nextBlockNumber - nextBlockNumber += 1 - pushBlock(blockId, buffer, null, storageLevel) - } - } - } - blockPushingThread.start() - - val lengthBuffer = ByteBuffer.allocate(4) - while (true) { - lengthBuffer.clear() - readFully(channel, lengthBuffer) - lengthBuffer.flip() - val length = lengthBuffer.getInt() - val dataBuffer = ByteBuffer.allocate(length) - readFully(channel, dataBuffer) - dataBuffer.flip() - logInfo("Read a block with " + length + " bytes") - queue.put(dataBuffer) - } - } - - def onStop() { - if (blockPushingThread != null) blockPushingThread.interrupt() - } - - /** Read a buffer fully from a given Channel */ - private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) { - while (dest.position < dest.limit) { - if (channel.read(dest) == -1) { - throw new EOFException("End of channel") - } - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala deleted file mode 100644 index f63a9e0011..0000000000 --- 
a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ /dev/null @@ -1,149 +0,0 @@ -package spark.streaming - -import spark.streaming.StreamingContext._ - -import spark.RDD -import spark.rdd.UnionRDD -import spark.rdd.CoGroupedRDD -import spark.Partitioner -import spark.SparkContext._ -import spark.storage.StorageLevel - -import scala.collection.mutable.ArrayBuffer -import collection.SeqProxy - -class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( - parent: DStream[(K, V)], - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - _windowTime: Time, - _slideTime: Time, - partitioner: Partitioner - ) extends DStream[(K,V)](parent.ssc) { - - assert(_windowTime.isMultipleOf(parent.slideTime), - "The window duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" - ) - - assert(_slideTime.isMultipleOf(parent.slideTime), - "The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" - ) - - // Reduce each batch of data using reduceByKey which will be further reduced by window - // by ReducedWindowedDStream - val reducedStream = parent.reduceByKey(reduceFunc, partitioner) - - // Persist RDDs to memory by default as these RDDs are going to be reused. - super.persist(StorageLevel.MEMORY_ONLY_SER) - reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) - - def windowTime: Time = _windowTime - - override def dependencies = List(reducedStream) - - override def slideTime: Time = _slideTime - - override val mustCheckpoint = true - - override def parentRememberDuration: Time = rememberDuration + windowTime - - override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { - super.persist(storageLevel) - reducedStream.persist(storageLevel) - this - } - - override def checkpoint(interval: Time): DStream[(K, V)] = { - super.checkpoint(interval) - //reducedStream.checkpoint(interval) - this - } - - override def compute(validTime: Time): Option[RDD[(K, V)]] = { - val reduceF = reduceFunc - val invReduceF = invReduceFunc - - val currentTime = validTime - val currentWindow = Interval(currentTime - windowTime + parent.slideTime, currentTime) - val previousWindow = currentWindow - slideTime - - logDebug("Window time = " + windowTime) - logDebug("Slide time = " + slideTime) - logDebug("ZeroTime = " + zeroTime) - logDebug("Current window = " + currentWindow) - logDebug("Previous window = " + previousWindow) - - // _____________________________ - // | previous window _________|___________________ - // |___________________| current window | --------------> Time - // |_____________________________| - // - // |________ _________| |________ _________| - // | | - // V V - // old RDDs new RDDs - // - - // Get the RDDs of the reduced values in "old time steps" - val oldRDDs = reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideTime) - logDebug("# old RDDs = " + oldRDDs.size) - - // Get the RDDs of the reduced values in "new time steps" - val newRDDs = reducedStream.slice(previousWindow.endTime + parent.slideTime, currentWindow.endTime) - logDebug("# new RDDs = " + newRDDs.size) - - // Get the RDD of the reduced value of the previous window - val previousWindowRDD = getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) - - // Make the list of RDDs that needs to cogrouped together for reducing their reduced values - val allRDDs = new 
ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs - - // Cogroup the reduced RDDs and merge the reduced values - val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner) - //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ - - val numOldValues = oldRDDs.size - val numNewValues = newRDDs.size - - val mergeValues = (seqOfValues: Seq[Seq[V]]) => { - if (seqOfValues.size != 1 + numOldValues + numNewValues) { - throw new Exception("Unexpected number of sequences of reduced values") - } - // Getting reduced values "old time steps" that will be removed from current window - val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) - // Getting reduced values "new time steps" - val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) - if (seqOfValues(0).isEmpty) { - // If previous window's reduce value does not exist, then at least new values should exist - if (newValues.isEmpty) { - throw new Exception("Neither previous window has value for key, nor new values found. " + - "Are you sure your key class hashes consistently?") - } - // Reduce the new values - newValues.reduce(reduceF) // return - } else { - // Get the previous window's reduced value - var tempValue = seqOfValues(0).head - // If old values exists, then inverse reduce then from previous value - if (!oldValues.isEmpty) { - tempValue = invReduceF(tempValue, oldValues.reduce(reduceF)) - } - // If new values exists, then reduce them with previous value - if (!newValues.isEmpty) { - tempValue = reduceF(tempValue, newValues.reduce(reduceF)) - } - tempValue // return - } - } - - val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) - - Some(mergedValuesRDD) - } - - -} - - diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index fd1fa77a24..aeb7c3eb0e 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -4,9 +4,6 @@ import util.{ManualClock, RecurringTimer, Clock} import spark.SparkEnv import spark.Logging -import scala.collection.mutable.HashMap - - private[streaming] class Scheduler(ssc: StreamingContext) extends Logging { diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala deleted file mode 100644 index a9e37c0ff0..0000000000 --- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala +++ /dev/null @@ -1,107 +0,0 @@ -package spark.streaming - -import spark.streaming.util.{RecurringTimer, SystemClock} -import spark.storage.StorageLevel - -import java.io._ -import java.net.Socket -import java.util.concurrent.ArrayBlockingQueue - -import scala.collection.mutable.ArrayBuffer -import scala.Serializable - -class SocketInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - bytesToObjects: InputStream => Iterator[T], - storageLevel: StorageLevel - ) extends NetworkInputDStream[T](ssc_) { - - def createReceiver(): NetworkReceiver[T] = { - new SocketReceiver(id, host, port, bytesToObjects, storageLevel) - } -} - - -class SocketReceiver[T: ClassManifest]( - streamId: Int, - host: String, - port: Int, - bytesToObjects: InputStream => Iterator[T], - storageLevel: StorageLevel - ) extends NetworkReceiver[T](streamId) { - - lazy protected val 
dataHandler = new DataHandler(this, storageLevel) - - override def getLocationPreference = None - - protected def onStart() { - logInfo("Connecting to " + host + ":" + port) - val socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) - dataHandler.start() - val iterator = bytesToObjects(socket.getInputStream()) - while(iterator.hasNext) { - val obj = iterator.next - dataHandler += obj - } - } - - protected def onStop() { - dataHandler.stop() - } - -} - - -object SocketReceiver { - - /** - * This methods translates the data from an inputstream (say, from a socket) - * to '\n' delimited strings and returns an iterator to access the strings. - */ - def bytesToLines(inputStream: InputStream): Iterator[String] = { - val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8")) - - val iterator = new Iterator[String] { - var gotNext = false - var finished = false - var nextValue: String = null - - private def getNext() { - try { - nextValue = dataInputStream.readLine() - if (nextValue == null) { - finished = true - } - } - gotNext = true - } - - override def hasNext: Boolean = { - if (!finished) { - if (!gotNext) { - getNext() - if (finished) { - dataInputStream.close() - } - } - } - !finished - } - - override def next(): String = { - if (finished) { - throw new NoSuchElementException("End of stream") - } - if (!gotNext) { - getNext() - } - gotNext = false - nextValue - } - } - iterator - } -} diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala deleted file mode 100644 index b7e4c1c30c..0000000000 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ /dev/null @@ -1,84 +0,0 @@ -package spark.streaming - -import spark.RDD -import spark.rdd.BlockRDD -import spark.Partitioner -import spark.rdd.MapPartitionsRDD -import spark.SparkContext._ -import spark.storage.StorageLevel - -class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( - parent: DStream[(K, V)], - updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], - partitioner: Partitioner, - preservePartitioning: Boolean - ) extends DStream[(K, S)](parent.ssc) { - - super.persist(StorageLevel.MEMORY_ONLY_SER) - - override def dependencies = List(parent) - - override def slideTime = parent.slideTime - - override val mustCheckpoint = true - - override def compute(validTime: Time): Option[RDD[(K, S)]] = { - - // Try to get the previous state RDD - getOrCompute(validTime - slideTime) match { - - case Some(prevStateRDD) => { // If previous state RDD exists - - // Try to get the parent RDD - parent.getOrCompute(validTime) match { - case Some(parentRDD) => { // If parent RDD exists, then compute as usual - - // Define the function for the mapPartition operation on cogrouped RDD; - // first map the cogrouped tuple to tuples of required type, - // and then apply the update function - val updateFuncLocal = updateFunc - val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { - val i = iterator.map(t => { - (t._1, t._2._1, t._2._2.headOption) - }) - updateFuncLocal(i) - } - val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) - val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning) - //logDebug("Generating state RDD for time " + validTime) - return Some(stateRDD) - } - case None => { // If parent RDD does not exist, then return old state RDD - return Some(prevStateRDD) - } - } - } - - case None => { // If previous session RDD does 
not exist (first input data) - - // Try to get the parent RDD - parent.getOrCompute(validTime) match { - case Some(parentRDD) => { // If parent RDD exists, then compute as usual - - // Define the function for the mapPartition operation on grouped RDD; - // first map the grouped tuple to tuples of required type, - // and then apply the update function - val updateFuncLocal = updateFunc - val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { - updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, None))) - } - - val groupedRDD = parentRDD.groupByKey(partitioner) - val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) - //logDebug("Generating state RDD for time " + validTime + " (first)") - return Some(sessionRDD) - } - case None => { // If parent RDD does not exist, then nothing to do! - //logDebug("Not generating state RDD (no previous state, no parent)") - return None - } - } - } - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 998fea849f..ef73049a81 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -1,10 +1,10 @@ package spark.streaming -import spark.RDD -import spark.Logging -import spark.SparkEnv -import spark.SparkContext +import spark.streaming.dstream._ + +import spark.{RDD, Logging, SparkEnv, SparkContext} import spark.storage.StorageLevel +import spark.util.MetadataCleaner import scala.collection.mutable.Queue @@ -18,7 +18,6 @@ import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.flume.source.avro.AvroFlumeEvent import org.apache.hadoop.fs.Path import java.util.UUID -import spark.util.MetadataCleaner /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -126,7 +125,7 @@ class StreamingContext private ( /** * Create an input stream that pulls messages form a Kafka Broker. * - * @param host Zookeper hostname. + * @param hostname Zookeper hostname. * @param port Zookeper port. * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed @@ -319,7 +318,7 @@ object StreamingContext { protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { if (prefix == null) { - time.millis.toString + time.milliseconds.toString } else if (suffix == null || suffix.length ==0) { prefix + "-" + time.milliseconds } else { diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 480d292d7c..2976e5e87b 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -1,6 +1,11 @@ package spark.streaming -case class Time(millis: Long) { +/** + * This class is simple wrapper class that represents time in UTC. 
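// A minimal sketch (names and durations assumed purely for illustration, not part of the
// patch) of how the arithmetic on the Time class below composes when sizing windows
// against the batch duration, including the newly added / operator and floor().
import spark.streaming.Time

object TimeArithmeticSketch {
  def main(args: Array[String]) {
    val batchDuration  = Time(1000)                        // one-second batches
    val windowDuration = batchDuration * 30                // a 30-second window
    val batchesPerWindow = windowDuration / batchDuration  // 30, using the new / operator
    val aligned = Time(123456).floor(batchDuration)        // truncated to a batch boundary: Time(123000)
    println("batches per window: " + batchesPerWindow + ", aligned time: " + aligned)
  }
}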
+ * @param millis Time in UTC long + */ + +case class Time(private val millis: Long) { def < (that: Time): Boolean = (this.millis < that.millis) @@ -15,7 +20,9 @@ case class Time(millis: Long) { def - (that: Time): Time = Time(millis - that.millis) def * (times: Int): Time = Time(millis * times) - + + def / (that: Time): Long = millis / that.millis + def floor(that: Time): Time = { val t = that.millis val m = math.floor(this.millis / t).toLong diff --git a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala deleted file mode 100644 index e4d2a634f5..0000000000 --- a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala +++ /dev/null @@ -1,39 +0,0 @@ -package spark.streaming - -import spark.RDD -import spark.rdd.UnionRDD -import spark.storage.StorageLevel - - -class WindowedDStream[T: ClassManifest]( - parent: DStream[T], - _windowTime: Time, - _slideTime: Time) - extends DStream[T](parent.ssc) { - - if (!_windowTime.isMultipleOf(parent.slideTime)) - throw new Exception("The window duration of WindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") - - if (!_slideTime.isMultipleOf(parent.slideTime)) - throw new Exception("The slide duration of WindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") - - parent.persist(StorageLevel.MEMORY_ONLY_SER) - - def windowTime: Time = _windowTime - - override def dependencies = List(parent) - - override def slideTime: Time = _slideTime - - override def parentRememberDuration: Time = rememberDuration + windowTime - - override def compute(validTime: Time): Option[RDD[T]] = { - val currentWindow = Interval(validTime - windowTime + parent.slideTime, validTime) - Some(new UnionRDD(ssc.sc, parent.slice(currentWindow))) - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala new file mode 100644 index 0000000000..2e427dadf7 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala @@ -0,0 +1,39 @@ +package spark.streaming.dstream + +import spark.{RDD, Partitioner} +import spark.rdd.CoGroupedRDD +import spark.streaming.{Time, DStream} + +class CoGroupedDStream[K : ClassManifest]( + parents: Seq[DStream[(_, _)]], + partitioner: Partitioner + ) extends DStream[(K, Seq[Seq[_]])](parents.head.ssc) { + + if (parents.length == 0) { + throw new IllegalArgumentException("Empty array of parents") + } + + if (parents.map(_.ssc).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different StreamingContexts") + } + + if (parents.map(_.slideTime).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different slide times") + } + + override def dependencies = parents.toList + + override def slideTime = parents.head.slideTime + + override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = { + val part = partitioner + val rdds = parents.flatMap(_.getOrCompute(validTime)) + if (rdds.size > 0) { + val q = new CoGroupedRDD[K](rdds, part) + Some(q) + } else { + None + } + } + +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala new file mode 100644 index 0000000000..41c3af4694 --- /dev/null +++ 
b/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala @@ -0,0 +1,19 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.streaming.{Time, StreamingContext} + +/** + * An input stream that always returns the same RDD on each timestep. Useful for testing. + */ +class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T]) + extends InputDStream[T](ssc_) { + + override def start() {} + + override def stop() {} + + override def compute(validTime: Time): Option[RDD[T]] = { + Some(rdd) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/dstream/DataHandler.scala b/streaming/src/main/scala/spark/streaming/dstream/DataHandler.scala new file mode 100644 index 0000000000..d737ba1ecc --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/DataHandler.scala @@ -0,0 +1,83 @@ +package spark.streaming.dstream + +import java.util.concurrent.ArrayBlockingQueue +import scala.collection.mutable.ArrayBuffer +import spark.Logging +import spark.streaming.util.{RecurringTimer, SystemClock} +import spark.storage.StorageLevel + + +/** + * This is a helper object that manages the data received from the socket. It divides + * the object received into small batches of 100s of milliseconds, pushes them as + * blocks into the block manager and reports the block IDs to the network input + * tracker. It starts two threads, one to periodically start a new batch and prepare + * the previous batch of as a block, the other to push the blocks into the block + * manager. + */ + class DataHandler[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel) + extends Serializable with Logging { + + case class Block(id: String, iterator: Iterator[T], metadata: Any = null) + + val clock = new SystemClock() + val blockInterval = 200L + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockStorageLevel = storageLevel + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def createBlock(blockId: String, iterator: Iterator[T]) : Block = { + new Block(blockId, iterator) + } + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + logInfo("Data handler stopped") + } + + def += (obj: T) { + currentBuffer += obj + } + + def updateCurrentBuffer(time: Long) { + try { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval) + val newBlock = createBlock(blockId, newBlockBuffer.toIterator) + blocksForPushing.add(newBlock) + } + } catch { + case ie: InterruptedException => + logInfo("Block interval timer thread interrupted") + case e: Exception => + receiver.stop() + } + } + + def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel) + } + } catch { + case ie: InterruptedException => + logInfo("Block pushing thread interrupted") + case e: Exception => + receiver.stop() + } + } + } \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala 
b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala new file mode 100644 index 0000000000..8cdaff467b --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala @@ -0,0 +1,110 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.rdd.UnionRDD +import spark.streaming.{StreamingContext, Time} + +import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} + +import scala.collection.mutable.HashSet + + +class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( + @transient ssc_ : StreamingContext, + directory: String, + filter: PathFilter = FileInputDStream.defaultPathFilter, + newFilesOnly: Boolean = true) + extends InputDStream[(K, V)](ssc_) { + + @transient private var path_ : Path = null + @transient private var fs_ : FileSystem = null + + var lastModTime = 0L + val lastModTimeFiles = new HashSet[String]() + + def path(): Path = { + if (path_ == null) path_ = new Path(directory) + path_ + } + + def fs(): FileSystem = { + if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) + fs_ + } + + override def start() { + if (newFilesOnly) { + lastModTime = System.currentTimeMillis() + } else { + lastModTime = 0 + } + } + + override def stop() { } + + /** + * Finds the files that were modified since the last time this method was called and makes + * a union RDD out of them. Note that this maintains the list of files that were processed + * in the latest modification time in the previous call to this method. This is because the + * modification time returned by the FileStatus API seems to return times only at the + * granularity of seconds. Hence, new files may have the same modification time as the + * latest modification time in the previous call to this method and the list of files + * maintained is used to filter the one that have been processed. 
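// A standalone sketch of the acceptance rule described above, with the Hadoop-specific
// pieces stripped out; the object name, paths and timestamps are assumptions chosen only
// to illustrate the second-granularity corner case handled by compute() below.
object NewFileRuleSketch {
  // Mirrors the decision made inside the PathFilter: reject strictly older files,
  // reject same-second files that were already processed, accept everything else.
  def isNew(modTime: Long, path: String, lastModTime: Long, alreadyProcessed: Set[String]): Boolean = {
    if (modTime < lastModTime) false
    else if (modTime == lastModTime && alreadyProcessed.contains(path)) false
    else true
  }

  def main(args: Array[String]) {
    val processed = Set("hdfs:///logs/part-00000")
    println(isNew(5000L, "hdfs:///logs/part-00000", 5000L, processed)) // false: same second, already seen
    println(isNew(5000L, "hdfs:///logs/part-00001", 5000L, processed)) // true: new file in the same second
    println(isNew(4000L, "hdfs:///logs/part-00002", 5000L, processed)) // false: older than the last call
  }
}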
+ */ + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + // Create the filter for selecting new files + val newFilter = new PathFilter() { + var latestModTime = 0L + val latestModTimeFiles = new HashSet[String]() + + def accept(path: Path): Boolean = { + if (!filter.accept(path)) { + return false + } else { + val modTime = fs.getFileStatus(path).getModificationTime() + if (modTime < lastModTime){ + return false + } else if (modTime == lastModTime && lastModTimeFiles.contains(path.toString)) { + return false + } + if (modTime > latestModTime) { + latestModTime = modTime + latestModTimeFiles.clear() + } + latestModTimeFiles += path.toString + return true + } + } + } + + val newFiles = fs.listStatus(path, newFilter) + logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) + if (newFiles.length > 0) { + // Update the modification time and the files processed for that modification time + if (lastModTime != newFilter.latestModTime) { + lastModTime = newFilter.latestModTime + lastModTimeFiles.clear() + } + lastModTimeFiles ++= newFilter.latestModTimeFiles + } + val newRDD = new UnionRDD(ssc.sc, newFiles.map( + file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) + Some(newRDD) + } +} + +object FileInputDStream { + val defaultPathFilter = new PathFilter with Serializable { + def accept(path: Path): Boolean = { + val file = path.getName() + if (file.startsWith(".") || file.endsWith("_tmp")) { + return false + } else { + return true + } + } + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala new file mode 100644 index 0000000000..1cbb4d536e --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala @@ -0,0 +1,21 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD + +private[streaming] +class FilteredDStream[T: ClassManifest]( + parent: DStream[T], + filterFunc: T => Boolean + ) extends DStream[T](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + parent.getOrCompute(validTime).map(_.filter(filterFunc)) + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala new file mode 100644 index 0000000000..11ed8cf317 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -0,0 +1,20 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD +import spark.SparkContext._ + +private[streaming] +class FlatMapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( + parent: DStream[(K, V)], + flatMapValueFunc: V => TraversableOnce[U] + ) extends DStream[(K, U)](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K, U)]] = { + parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc)) + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala new file mode 100644 index 0000000000..a13b4c9ff9 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala @@ -0,0 +1,20 @@ +package 
spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD + +private[streaming] +class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( + parent: DStream[T], + flatMapFunc: T => Traversable[U] + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala new file mode 100644 index 0000000000..7e988cadf4 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala @@ -0,0 +1,135 @@ +package spark.streaming.dstream + +import spark.streaming.StreamingContext + +import spark.Utils +import spark.storage.StorageLevel + +import org.apache.flume.source.avro.AvroSourceProtocol +import org.apache.flume.source.avro.AvroFlumeEvent +import org.apache.flume.source.avro.Status +import org.apache.avro.ipc.specific.SpecificResponder +import org.apache.avro.ipc.NettyServer + +import scala.collection.JavaConversions._ + +import java.net.InetSocketAddress +import java.io.{ObjectInput, ObjectOutput, Externalizable} +import java.nio.ByteBuffer + +class FlumeInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + storageLevel: StorageLevel +) extends NetworkInputDStream[SparkFlumeEvent](ssc_) { + + override def createReceiver(): NetworkReceiver[SparkFlumeEvent] = { + new FlumeReceiver(id, host, port, storageLevel) + } +} + +/** + * A wrapper class for AvroFlumeEvent's with a custom serialization format. + * + * This is necessary because AvroFlumeEvent uses inner data structures + * which are not serializable. + */ +class SparkFlumeEvent() extends Externalizable { + var event : AvroFlumeEvent = new AvroFlumeEvent() + + /* De-serialize from bytes. */ + def readExternal(in: ObjectInput) { + val bodyLength = in.readInt() + val bodyBuff = new Array[Byte](bodyLength) + in.read(bodyBuff) + + val numHeaders = in.readInt() + val headers = new java.util.HashMap[CharSequence, CharSequence] + + for (i <- 0 until numHeaders) { + val keyLength = in.readInt() + val keyBuff = new Array[Byte](keyLength) + in.read(keyBuff) + val key : String = Utils.deserialize(keyBuff) + + val valLength = in.readInt() + val valBuff = new Array[Byte](valLength) + in.read(valBuff) + val value : String = Utils.deserialize(valBuff) + + headers.put(key, value) + } + + event.setBody(ByteBuffer.wrap(bodyBuff)) + event.setHeaders(headers) + } + + /* Serialize to bytes. */ + def writeExternal(out: ObjectOutput) { + val body = event.getBody.array() + out.writeInt(body.length) + out.write(body) + + val numHeaders = event.getHeaders.size() + out.writeInt(numHeaders) + for ((k, v) <- event.getHeaders) { + val keyBuff = Utils.serialize(k.toString) + out.writeInt(keyBuff.length) + out.write(keyBuff) + val valBuff = Utils.serialize(v.toString) + out.writeInt(valBuff.length) + out.write(valBuff) + } + } +} + +private[streaming] object SparkFlumeEvent { + def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = { + val event = new SparkFlumeEvent + event.event = in + event + } +} + +/** A simple server that implements Flume's Avro protocol. 
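// A minimal sketch (object name and input values are assumptions for illustration) of why
// the wrapper above exists: an AvroFlumeEvent cannot go through Java serialization
// directly, but a SparkFlumeEvent can, because its Externalizable hooks re-encode the
// body and headers.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.nio.ByteBuffer
import org.apache.flume.source.avro.AvroFlumeEvent
import spark.streaming.dstream.SparkFlumeEvent

object SparkFlumeEventSketch {
  def main(args: Array[String]) {
    val avroEvent = new AvroFlumeEvent()
    avroEvent.setBody(ByteBuffer.wrap("hello flume".getBytes("UTF-8")))
    avroEvent.setHeaders(new java.util.HashMap[CharSequence, CharSequence]())

    val wrapped = new SparkFlumeEvent()
    wrapped.event = avroEvent

    // Round-trip through standard Java serialization, as happens when events are shipped
    // between nodes.
    val bytesOut = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bytesOut)
    oos.writeObject(wrapped)
    oos.flush()
    val restored = new ObjectInputStream(new ByteArrayInputStream(bytesOut.toByteArray))
      .readObject().asInstanceOf[SparkFlumeEvent]

    println(new String(restored.event.getBody.array(), "UTF-8"))  // prints "hello flume"
  }
}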
*/ +class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { + override def append(event : AvroFlumeEvent) : Status = { + receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event) + Status.OK + } + + override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { + events.foreach (event => + receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event)) + Status.OK + } +} + +/** A NetworkReceiver which listens for events using the + * Flume Avro interface.*/ +class FlumeReceiver( + streamId: Int, + host: String, + port: Int, + storageLevel: StorageLevel + ) extends NetworkReceiver[SparkFlumeEvent](streamId) { + + lazy val dataHandler = new DataHandler(this, storageLevel) + + protected override def onStart() { + val responder = new SpecificResponder( + classOf[AvroSourceProtocol], new FlumeEventServer(this)); + val server = new NettyServer(responder, new InetSocketAddress(host, port)); + dataHandler.start() + server.start() + logInfo("Flume receiver started") + } + + protected override def onStop() { + dataHandler.stop() + logInfo("Flume receiver stopped") + } + + override def getLocationPreference = Some(host) +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala new file mode 100644 index 0000000000..41c629a225 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala @@ -0,0 +1,28 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.streaming.{DStream, Job, Time} + +private[streaming] +class ForEachDStream[T: ClassManifest] ( + parent: DStream[T], + foreachFunc: (RDD[T], Time) => Unit + ) extends DStream[Unit](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Unit]] = None + + override def generateJob(time: Time): Option[Job] = { + parent.getOrCompute(time) match { + case Some(rdd) => + val jobFunc = () => { + foreachFunc(rdd, time) + } + Some(new Job(time, jobFunc)) + case None => None + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala new file mode 100644 index 0000000000..92ea503cae --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala @@ -0,0 +1,17 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD + +private[streaming] +class GlommedDStream[T: ClassManifest](parent: DStream[T]) + extends DStream[Array[T]](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Array[T]]] = { + parent.getOrCompute(validTime).map(_.glom()) + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala new file mode 100644 index 0000000000..4959c66b06 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala @@ -0,0 +1,19 @@ +package spark.streaming.dstream + +import spark.streaming.{StreamingContext, DStream} + +abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) + extends DStream[T](ssc_) { + + override def dependencies = List() + + override def slideTime = { + if (ssc == null) throw new Exception("ssc is 
null") + if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null") + ssc.graph.batchDuration + } + + def start() + + def stop() +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala new file mode 100644 index 0000000000..a46721af2f --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -0,0 +1,197 @@ +package spark.streaming.dstream + +import spark.Logging +import spark.storage.StorageLevel +import spark.streaming.{Time, DStreamCheckpointData, StreamingContext} + +import java.util.Properties +import java.util.concurrent.Executors + +import kafka.consumer._ +import kafka.message.{Message, MessageSet, MessageAndMetadata} +import kafka.serializer.StringDecoder +import kafka.utils.{Utils, ZKGroupTopicDirs} +import kafka.utils.ZkUtils._ + +import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ + + +// Key for a specific Kafka Partition: (broker, topic, group, part) +case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) +// NOT USED - Originally intended for fault-tolerance +// Metadata for a Kafka Stream that it sent to the Master +case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long]) +// NOT USED - Originally intended for fault-tolerance +// Checkpoint data specific to a KafkaInputDstream +case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], + savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds) + +/** + * Input stream that pulls messages from a Kafka Broker. + * + * @param host Zookeper hostname. + * @param port Zookeper port. + * @param groupId The group id for this consumer. + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param initialOffsets Optional initial offsets for each of the partitions to consume. + * By default the value is pulled from zookeper. + * @param storageLevel RDD storage level. + */ +class KafkaInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + groupId: String, + topics: Map[String, Int], + initialOffsets: Map[KafkaPartitionKey, Long], + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_ ) with Logging { + + // Metadata that keeps track of which messages have already been consumed. + var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]() + + /* NOT USED - Originally intended for fault-tolerance + + // In case of a failure, the offets for a particular timestamp will be restored. 
+ @transient var restoredOffsets : Map[KafkaPartitionKey, Long] = null + + + override protected[streaming] def addMetadata(metadata: Any) { + metadata match { + case x : KafkaInputDStreamMetadata => + savedOffsets(x.timestamp) = x.data + // TOOD: Remove logging + logInfo("New saved Offsets: " + savedOffsets) + case _ => logInfo("Received unknown metadata: " + metadata.toString) + } + } + + override protected[streaming] def updateCheckpointData(currentTime: Time) { + super.updateCheckpointData(currentTime) + if(savedOffsets.size > 0) { + // Find the offets that were stored before the checkpoint was initiated + val key = savedOffsets.keys.toList.sortWith(_ < _).filter(_ < currentTime.millis).last + val latestOffsets = savedOffsets(key) + logInfo("Updating KafkaDStream checkpoint data: " + latestOffsets.toString) + checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, latestOffsets) + // TODO: This may throw out offsets that are created after the checkpoint, + // but it's unlikely we'll need them. + savedOffsets.clear() + } + } + + override protected[streaming] def restoreCheckpointData() { + super.restoreCheckpointData() + logInfo("Restoring KafkaDStream checkpoint data.") + checkpointData match { + case x : KafkaDStreamCheckpointData => + restoredOffsets = x.savedOffsets + logInfo("Restored KafkaDStream offsets: " + savedOffsets) + } + } */ + + def createReceiver(): NetworkReceiver[T] = { + new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel) + .asInstanceOf[NetworkReceiver[T]] + } +} + +class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, + topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], + storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) { + + // Timeout for establishing a connection to Zookeper in ms. + val ZK_TIMEOUT = 10000 + + // Handles pushing data into the BlockManager + lazy protected val dataHandler = new DataHandler(this, storageLevel) + // Keeps track of the current offsets. 
Maps from (broker, topic, group, part) -> Offset + lazy val offsets = HashMap[KafkaPartitionKey, Long]() + // Connection to Kafka + var consumerConnector : ZookeeperConsumerConnector = null + + def onStop() { + dataHandler.stop() + } + + def onStart() { + + // Starting the DataHandler that buffers blocks and pushes them into them BlockManager + dataHandler.start() + + // In case we are using multiple Threads to handle Kafka Messages + val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) + + val zooKeeperEndPoint = host + ":" + port + logInfo("Starting Kafka Consumer Stream with group: " + groupId) + logInfo("Initial offsets: " + initialOffsets.toString) + + // Zookeper connection properties + val props = new Properties() + props.put("zk.connect", zooKeeperEndPoint) + props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString) + props.put("groupid", groupId) + + // Create the connection to the cluster + logInfo("Connecting to Zookeper: " + zooKeeperEndPoint) + val consumerConfig = new ConsumerConfig(props) + consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] + logInfo("Connected to " + zooKeeperEndPoint) + + // Reset the Kafka offsets in case we are recovering from a failure + resetOffsets(initialOffsets) + + // Create Threads for each Topic/Message Stream we are listening + val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) + + // Start the messages handler for each partition + topicMessageStreams.values.foreach { streams => + streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } + } + + } + + // Overwrites the offets in Zookeper. + private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) { + offsets.foreach { case(key, offset) => + val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic) + val partitionName = key.brokerId + "-" + key.partId + updatePersistentPath(consumerConnector.zkClient, + topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString) + } + } + + // Handles Kafka Messages + private class MessageHandler(stream: KafkaStream[String]) extends Runnable { + def run() { + logInfo("Starting MessageHandler.") + stream.takeWhile { msgAndMetadata => + dataHandler += msgAndMetadata.message + + // Updating the offet. The key is (broker, topic, group, partition). 
+ val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic, + groupId, msgAndMetadata.topicInfo.partition.partId) + val offset = msgAndMetadata.topicInfo.getConsumeOffset + offsets.put(key, offset) + // logInfo("Handled message: " + (key, offset).toString) + + // Keep on handling messages + true + } + } + } + + // NOT USED - Originally intended for fault-tolerance + // class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel) + // extends DataHandler[Any](receiver, storageLevel) { + + // override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = { + // // Creates a new Block with Kafka-specific Metadata + // new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap)) + // } + + // } + +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala new file mode 100644 index 0000000000..daf78c6893 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala @@ -0,0 +1,21 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD + +private[streaming] +class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( + parent: DStream[T], + mapPartFunc: Iterator[T] => Iterator[U], + preservePartitioning: Boolean + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc, preservePartitioning)) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala new file mode 100644 index 0000000000..689caeef0e --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala @@ -0,0 +1,21 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD +import spark.SparkContext._ + +private[streaming] +class MapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( + parent: DStream[(K, V)], + mapValueFunc: V => U + ) extends DStream[(K, U)](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K, U)]] = { + parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc)) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala new file mode 100644 index 0000000000..786b9966f2 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala @@ -0,0 +1,20 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD + +private[streaming] +class MappedDStream[T: ClassManifest, U: ClassManifest] ( + parent: DStream[T], + mapFunc: T => U + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.map[U](mapFunc)) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala new file mode 100644 index 
0000000000..41276da8bb --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala @@ -0,0 +1,157 @@ +package spark.streaming.dstream + +import spark.streaming.{Time, StreamingContext, AddBlocks, RegisterReceiver, DeregisterReceiver} + +import spark.{Logging, SparkEnv, RDD} +import spark.rdd.BlockRDD +import spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer + +import java.nio.ByteBuffer + +import akka.actor.{Props, Actor} +import akka.pattern.ask +import akka.dispatch.Await +import akka.util.duration._ + +abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext) + extends InputDStream[T](ssc_) { + + // This is an unique identifier that is used to match the network receiver with the + // corresponding network input stream. + val id = ssc.getNewNetworkStreamId() + + /** + * This method creates the receiver object that will be sent to the workers + * to receive data. This method needs to defined by any specific implementation + * of a NetworkInputDStream. + */ + def createReceiver(): NetworkReceiver[T] + + // Nothing to start or stop as both taken care of by the NetworkInputTracker. + def start() {} + + def stop() {} + + override def compute(validTime: Time): Option[RDD[T]] = { + val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) + Some(new BlockRDD[T](ssc.sc, blockIds)) + } +} + + +private[streaming] sealed trait NetworkReceiverMessage +private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage +private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage +private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage + +abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging { + + initLogging() + + lazy protected val env = SparkEnv.get + + lazy protected val actor = env.actorSystem.actorOf( + Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId) + + lazy protected val receivingThread = Thread.currentThread() + + /** This method will be called to start receiving data. */ + protected def onStart() + + /** This method will be called to stop receiving data. */ + protected def onStop() + + /** This method conveys a placement preference (hostname) for this receiver. */ + def getLocationPreference() : Option[String] = None + + /** + * This method starts the receiver. First is accesses all the lazy members to + * materialize them. Then it calls the user-defined onStart() method to start + * other threads, etc required to receiver the data. + */ + def start() { + try { + // Access the lazy vals to materialize them + env + actor + receivingThread + + // Call user-defined onStart() + onStart() + } catch { + case ie: InterruptedException => + logInfo("Receiving thread interrupted") + //println("Receiving thread interrupted") + case e: Exception => + stopOnError(e) + } + } + + /** + * This method stops the receiver. First it interrupts the main receiving thread, + * that is, the thread that called receiver.start(). Then it calls the user-defined + * onStop() method to stop other threads and/or do cleanup. + */ + def stop() { + receivingThread.interrupt() + onStop() + //TODO: terminate the actor + } + + /** + * This method stops the receiver and reports to exception to the tracker. + * This should be called whenever an exception has happened on any thread + * of the receiver. 
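// A minimal sketch of a concrete receiver built against the contract defined in this file:
// onStart() produces data and hands it to pushBlock(), while onStop() asks the producing
// loop to finish. The class name, in-memory data source, batch size and storage level are
// assumptions for illustration, not an existing Spark receiver.
import spark.storage.StorageLevel
import spark.streaming.dstream.NetworkReceiver

class ConstantStringReceiver(streamId: Int) extends NetworkReceiver[String](streamId) {

  @volatile private var stopped = false

  protected def onStart() {
    var blockNumber = 0
    while (!stopped) {
      // Build a small batch and push it into the block manager as one block.
      val batch = (1 to 100).iterator.map(i => "record-" + blockNumber + "-" + i)
      pushBlock("input-" + streamId + "-" + blockNumber, batch, null, StorageLevel.MEMORY_ONLY_SER)
      blockNumber += 1
      Thread.sleep(1000)
    }
  }

  protected def onStop() {
    stopped = true
  }
}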
+ */ + protected def stopOnError(e: Exception) { + logError("Error receiving data", e) + stop() + actor ! ReportError(e.toString) + } + + + /** + * This method pushes a block (as iterator of values) into the block manager. + */ + def pushBlock(blockId: String, iterator: Iterator[T], metadata: Any, level: StorageLevel) { + val buffer = new ArrayBuffer[T] ++ iterator + env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level) + + actor ! ReportBlock(blockId, metadata) + } + + /** + * This method pushes a block (as bytes) into the block manager. + */ + def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { + env.blockManager.putBytes(blockId, bytes, level) + actor ! ReportBlock(blockId, metadata) + } + + /** A helper actor that communicates with the NetworkInputTracker */ + private class NetworkReceiverActor extends Actor { + logInfo("Attempting to register with tracker") + val ip = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port) + val tracker = env.actorSystem.actorFor(url) + val timeout = 5.seconds + + override def preStart() { + val future = tracker.ask(RegisterReceiver(streamId, self))(timeout) + Await.result(future, timeout) + } + + override def receive() = { + case ReportBlock(blockId, metadata) => + tracker ! AddBlocks(streamId, Array(blockId), metadata) + case ReportError(msg) => + tracker ! DeregisterReceiver(streamId, msg) + case StopReceiver(msg) => + stop() + tracker ! DeregisterReceiver(streamId, msg) + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala new file mode 100644 index 0000000000..024bf3bea4 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala @@ -0,0 +1,41 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.rdd.UnionRDD + +import scala.collection.mutable.Queue +import scala.collection.mutable.ArrayBuffer +import spark.streaming.{Time, StreamingContext} + +class QueueInputDStream[T: ClassManifest]( + @transient ssc: StreamingContext, + val queue: Queue[RDD[T]], + oneAtATime: Boolean, + defaultRDD: RDD[T] + ) extends InputDStream[T](ssc) { + + override def start() { } + + override def stop() { } + + override def compute(validTime: Time): Option[RDD[T]] = { + val buffer = new ArrayBuffer[RDD[T]]() + if (oneAtATime && queue.size > 0) { + buffer += queue.dequeue() + } else { + buffer ++= queue + } + if (buffer.size > 0) { + if (oneAtATime) { + Some(buffer.first) + } else { + Some(new UnionRDD(ssc.sc, buffer.toSeq)) + } + } else if (defaultRDD != null) { + Some(defaultRDD) + } else { + None + } + } + +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala new file mode 100644 index 0000000000..996cc7dea8 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala @@ -0,0 +1,88 @@ +package spark.streaming.dstream + +import spark.{DaemonThread, Logging} +import spark.storage.StorageLevel +import spark.streaming.StreamingContext + +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.{ReadableByteChannel, SocketChannel} +import java.io.EOFException +import java.util.concurrent.ArrayBlockingQueue + + +/** + * An input stream that reads blocks 
of serialized objects from a given network address. + * The blocks will be inserted directly into the block store. This is the fastest way to get + * data into Spark Streaming, though it requires the sender to batch data and serialize it + * in the format that the system is configured with. + */ +class RawInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_ ) with Logging { + + def createReceiver(): NetworkReceiver[T] = { + new RawNetworkReceiver(id, host, port, storageLevel).asInstanceOf[NetworkReceiver[T]] + } +} + +class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel) + extends NetworkReceiver[Any](streamId) { + + var blockPushingThread: Thread = null + + override def getLocationPreference = None + + def onStart() { + // Open a socket to the target address and keep reading from it + logInfo("Connecting to " + host + ":" + port) + val channel = SocketChannel.open() + channel.configureBlocking(true) + channel.connect(new InetSocketAddress(host, port)) + logInfo("Connected to " + host + ":" + port) + + val queue = new ArrayBlockingQueue[ByteBuffer](2) + + blockPushingThread = new DaemonThread { + override def run() { + var nextBlockNumber = 0 + while (true) { + val buffer = queue.take() + val blockId = "input-" + streamId + "-" + nextBlockNumber + nextBlockNumber += 1 + pushBlock(blockId, buffer, null, storageLevel) + } + } + } + blockPushingThread.start() + + val lengthBuffer = ByteBuffer.allocate(4) + while (true) { + lengthBuffer.clear() + readFully(channel, lengthBuffer) + lengthBuffer.flip() + val length = lengthBuffer.getInt() + val dataBuffer = ByteBuffer.allocate(length) + readFully(channel, dataBuffer) + dataBuffer.flip() + logInfo("Read a block with " + length + " bytes") + queue.put(dataBuffer) + } + } + + def onStop() { + if (blockPushingThread != null) blockPushingThread.interrupt() + } + + /** Read a buffer fully from a given Channel */ + private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) { + while (dest.position < dest.limit) { + if (channel.read(dest) == -1) { + throw new EOFException("End of channel") + } + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala new file mode 100644 index 0000000000..2686de14d2 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -0,0 +1,148 @@ +package spark.streaming.dstream + +import spark.streaming.StreamingContext._ + +import spark.RDD +import spark.rdd.CoGroupedRDD +import spark.Partitioner +import spark.SparkContext._ +import spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer +import spark.streaming.{Interval, Time, DStream} + +class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( + parent: DStream[(K, V)], + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + _windowTime: Time, + _slideTime: Time, + partitioner: Partitioner + ) extends DStream[(K,V)](parent.ssc) { + + assert(_windowTime.isMultipleOf(parent.slideTime), + "The window duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + ) + + assert(_slideTime.isMultipleOf(parent.slideTime), + "The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of 
parent DStream (" + parent.slideTime + ")" + ) + + // Reduce each batch of data using reduceByKey which will be further reduced by window + // by ReducedWindowedDStream + val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + + // Persist RDDs to memory by default as these RDDs are going to be reused. + super.persist(StorageLevel.MEMORY_ONLY_SER) + reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) + + def windowTime: Time = _windowTime + + override def dependencies = List(reducedStream) + + override def slideTime: Time = _slideTime + + override val mustCheckpoint = true + + override def parentRememberDuration: Time = rememberDuration + windowTime + + override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { + super.persist(storageLevel) + reducedStream.persist(storageLevel) + this + } + + override def checkpoint(interval: Time): DStream[(K, V)] = { + super.checkpoint(interval) + //reducedStream.checkpoint(interval) + this + } + + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + val reduceF = reduceFunc + val invReduceF = invReduceFunc + + val currentTime = validTime + val currentWindow = Interval(currentTime - windowTime + parent.slideTime, currentTime) + val previousWindow = currentWindow - slideTime + + logDebug("Window time = " + windowTime) + logDebug("Slide time = " + slideTime) + logDebug("ZeroTime = " + zeroTime) + logDebug("Current window = " + currentWindow) + logDebug("Previous window = " + previousWindow) + + // _____________________________ + // | previous window _________|___________________ + // |___________________| current window | --------------> Time + // |_____________________________| + // + // |________ _________| |________ _________| + // | | + // V V + // old RDDs new RDDs + // + + // Get the RDDs of the reduced values in "old time steps" + val oldRDDs = reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideTime) + logDebug("# old RDDs = " + oldRDDs.size) + + // Get the RDDs of the reduced values in "new time steps" + val newRDDs = reducedStream.slice(previousWindow.endTime + parent.slideTime, currentWindow.endTime) + logDebug("# new RDDs = " + newRDDs.size) + + // Get the RDD of the reduced value of the previous window + val previousWindowRDD = getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) + + // Make the list of RDDs that needs to cogrouped together for reducing their reduced values + val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs + + // Cogroup the reduced RDDs and merge the reduced values + val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner) + //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ + + val numOldValues = oldRDDs.size + val numNewValues = newRDDs.size + + val mergeValues = (seqOfValues: Seq[Seq[V]]) => { + if (seqOfValues.size != 1 + numOldValues + numNewValues) { + throw new Exception("Unexpected number of sequences of reduced values") + } + // Getting reduced values "old time steps" that will be removed from current window + val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) + // Getting reduced values "new time steps" + val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) + if (seqOfValues(0).isEmpty) { + // If previous window's reduce value does not exist, then at least new values should exist + if (newValues.isEmpty) { + throw new Exception("Neither 
previous window has value for key, nor new values found. " + + "Are you sure your key class hashes consistently?") + } + // Reduce the new values + newValues.reduce(reduceF) // return + } else { + // Get the previous window's reduced value + var tempValue = seqOfValues(0).head + // If old values exists, then inverse reduce then from previous value + if (!oldValues.isEmpty) { + tempValue = invReduceF(tempValue, oldValues.reduce(reduceF)) + } + // If new values exists, then reduce them with previous value + if (!newValues.isEmpty) { + tempValue = reduceF(tempValue, newValues.reduce(reduceF)) + } + tempValue // return + } + } + + val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) + + Some(mergedValuesRDD) + } + + +} + + diff --git a/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala new file mode 100644 index 0000000000..6854bbe665 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala @@ -0,0 +1,27 @@ +package spark.streaming.dstream + +import spark.{RDD, Partitioner} +import spark.SparkContext._ +import spark.streaming.{DStream, Time} + +private[streaming] +class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( + parent: DStream[(K,V)], + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + partitioner: Partitioner + ) extends DStream [(K,C)] (parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K,C)]] = { + parent.getOrCompute(validTime) match { + case Some(rdd) => + Some(rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner)) + case None => None + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala new file mode 100644 index 0000000000..af5b73ae8d --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala @@ -0,0 +1,103 @@ +package spark.streaming.dstream + +import spark.streaming.StreamingContext +import spark.storage.StorageLevel + +import java.io._ +import java.net.Socket + +class SocketInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T], + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_) { + + def createReceiver(): NetworkReceiver[T] = { + new SocketReceiver(id, host, port, bytesToObjects, storageLevel) + } +} + + +class SocketReceiver[T: ClassManifest]( + streamId: Int, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T], + storageLevel: StorageLevel + ) extends NetworkReceiver[T](streamId) { + + lazy protected val dataHandler = new DataHandler(this, storageLevel) + + override def getLocationPreference = None + + protected def onStart() { + logInfo("Connecting to " + host + ":" + port) + val socket = new Socket(host, port) + logInfo("Connected to " + host + ":" + port) + dataHandler.start() + val iterator = bytesToObjects(socket.getInputStream()) + while(iterator.hasNext) { + val obj = iterator.next + dataHandler += obj + } + } + + protected def onStop() { + dataHandler.stop() + } + +} + + +object SocketReceiver { + + /** + * This methods translates the data from an inputstream (say, from a socket) + * to '\n' delimited strings and returns an iterator 
to access the strings. + */ + def bytesToLines(inputStream: InputStream): Iterator[String] = { + val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8")) + + val iterator = new Iterator[String] { + var gotNext = false + var finished = false + var nextValue: String = null + + private def getNext() { + try { + nextValue = dataInputStream.readLine() + if (nextValue == null) { + finished = true + } + } + gotNext = true + } + + override def hasNext: Boolean = { + if (!finished) { + if (!gotNext) { + getNext() + if (finished) { + dataInputStream.close() + } + } + } + !finished + } + + override def next(): String = { + if (finished) { + throw new NoSuchElementException("End of stream") + } + if (!gotNext) { + getNext() + } + gotNext = false + nextValue + } + } + iterator + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala new file mode 100644 index 0000000000..6e190b5564 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala @@ -0,0 +1,83 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.Partitioner +import spark.SparkContext._ +import spark.storage.StorageLevel +import spark.streaming.{Time, DStream} + +class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( + parent: DStream[(K, V)], + updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], + partitioner: Partitioner, + preservePartitioning: Boolean + ) extends DStream[(K, S)](parent.ssc) { + + super.persist(StorageLevel.MEMORY_ONLY_SER) + + override def dependencies = List(parent) + + override def slideTime = parent.slideTime + + override val mustCheckpoint = true + + override def compute(validTime: Time): Option[RDD[(K, S)]] = { + + // Try to get the previous state RDD + getOrCompute(validTime - slideTime) match { + + case Some(prevStateRDD) => { // If previous state RDD exists + + // Try to get the parent RDD + parent.getOrCompute(validTime) match { + case Some(parentRDD) => { // If parent RDD exists, then compute as usual + + // Define the function for the mapPartition operation on cogrouped RDD; + // first map the cogrouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { + val i = iterator.map(t => { + (t._1, t._2._1, t._2._2.headOption) + }) + updateFuncLocal(i) + } + val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) + val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning) + //logDebug("Generating state RDD for time " + validTime) + return Some(stateRDD) + } + case None => { // If parent RDD does not exist, then return old state RDD + return Some(prevStateRDD) + } + } + } + + case None => { // If previous session RDD does not exist (first input data) + + // Try to get the parent RDD + parent.getOrCompute(validTime) match { + case Some(parentRDD) => { // If parent RDD exists, then compute as usual + + // Define the function for the mapPartition operation on grouped RDD; + // first map the grouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { + updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, None))) + } + + val groupedRDD = parentRDD.groupByKey(partitioner) + val sessionRDD = groupedRDD.mapPartitions(finalFunc, 
preservePartitioning) + //logDebug("Generating state RDD for time " + validTime + " (first)") + return Some(sessionRDD) + } + case None => { // If parent RDD does not exist, then nothing to do! + //logDebug("Not generating state RDD (no previous state, no parent)") + return None + } + } + } + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala new file mode 100644 index 0000000000..0337579514 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala @@ -0,0 +1,19 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.streaming.{DStream, Time} + +private[streaming] +class TransformedDStream[T: ClassManifest, U: ClassManifest] ( + parent: DStream[T], + transformFunc: (RDD[T], Time) => RDD[U] + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(transformFunc(_, validTime)) + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala new file mode 100644 index 0000000000..f1efb2ae72 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala @@ -0,0 +1,39 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD +import collection.mutable.ArrayBuffer +import spark.rdd.UnionRDD + +class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) + extends DStream[T](parents.head.ssc) { + + if (parents.length == 0) { + throw new IllegalArgumentException("Empty array of parents") + } + + if (parents.map(_.ssc).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different StreamingContexts") + } + + if (parents.map(_.slideTime).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different slide times") + } + + override def dependencies = parents.toList + + override def slideTime: Time = parents.head.slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + val rdds = new ArrayBuffer[RDD[T]]() + parents.map(_.getOrCompute(validTime)).foreach(_ match { + case Some(rdd) => rdds += rdd + case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) + }) + if (rdds.size > 0) { + Some(new UnionRDD(ssc.sc, rdds)) + } else { + None + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala new file mode 100644 index 0000000000..4b2621c497 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala @@ -0,0 +1,40 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.rdd.UnionRDD +import spark.storage.StorageLevel +import spark.streaming.{Interval, Time, DStream} + + +class WindowedDStream[T: ClassManifest]( + parent: DStream[T], + _windowTime: Time, + _slideTime: Time) + extends DStream[T](parent.ssc) { + + if (!_windowTime.isMultipleOf(parent.slideTime)) + throw new Exception("The window duration of WindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + + if (!_slideTime.isMultipleOf(parent.slideTime)) + throw new Exception("The slide duration of 
WindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + + parent.persist(StorageLevel.MEMORY_ONLY_SER) + + def windowTime: Time = _windowTime + + override def dependencies = List(parent) + + override def slideTime: Time = _slideTime + + override def parentRememberDuration: Time = rememberDuration + windowTime + + override def compute(validTime: Time): Option[RDD[T]] = { + val currentWindow = Interval(validTime - windowTime + parent.slideTime, validTime) + Some(new UnionRDD(ssc.sc, parent.slice(currentWindow))) + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala index 7c4ee3b34c..dfaaf03f03 100644 --- a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala @@ -25,7 +25,7 @@ object GrepRaw { val rawStreams = (1 to numStreams).map(_ => ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray - val union = new UnionDStream(rawStreams) + val union = ssc.union(rawStreams) union.filter(_.contains("Alice")).count().foreach(r => println("Grep count: " + r.collect().mkString)) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index 182dfd8a52..338834bc3c 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -34,7 +34,7 @@ object TopKWordCountRaw { val lines = (1 to numStreams).map(_ => { ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) }) - val union = new UnionDStream(lines.toArray) + val union = ssc.union(lines) val counts = union.mapPartitions(splitAndCountPartitions) val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index 9bcd30f4d7..d93335a8ce 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -33,7 +33,7 @@ object WordCountRaw { val lines = (1 to numStreams).map(_ => { ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) }) - val union = new UnionDStream(lines.toArray) + val union = ssc.union(lines) val counts = union.mapPartitions(splitAndCountPartitions) val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) windowedCounts.foreach(r => println("# unique words = " + r.count())) diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala deleted file mode 100644 index 7c642d4802..0000000000 --- a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala +++ /dev/null @@ -1,193 +0,0 @@ -package spark.streaming - -import java.util.Properties -import java.util.concurrent.Executors -import kafka.consumer._ -import kafka.message.{Message, MessageSet, MessageAndMetadata} -import kafka.serializer.StringDecoder -import kafka.utils.{Utils, ZKGroupTopicDirs} -import kafka.utils.ZkUtils._ -import 
scala.collection.mutable.HashMap -import scala.collection.JavaConversions._ -import spark._ -import spark.RDD -import spark.storage.StorageLevel - -// Key for a specific Kafka Partition: (broker, topic, group, part) -case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) -// NOT USED - Originally intended for fault-tolerance -// Metadata for a Kafka Stream that it sent to the Master -case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long]) -// NOT USED - Originally intended for fault-tolerance -// Checkpoint data specific to a KafkaInputDstream -case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], - savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds) - -/** - * Input stream that pulls messages form a Kafka Broker. - * - * @param host Zookeper hostname. - * @param port Zookeper port. - * @param groupId The group id for this consumer. - * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. - * By default the value is pulled from zookeper. - * @param storageLevel RDD storage level. - */ -class KafkaInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - groupId: String, - topics: Map[String, Int], - initialOffsets: Map[KafkaPartitionKey, Long], - storageLevel: StorageLevel - ) extends NetworkInputDStream[T](ssc_ ) with Logging { - - // Metadata that keeps track of which messages have already been consumed. - var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]() - - /* NOT USED - Originally intended for fault-tolerance - - // In case of a failure, the offets for a particular timestamp will be restored. - @transient var restoredOffsets : Map[KafkaPartitionKey, Long] = null - - - override protected[streaming] def addMetadata(metadata: Any) { - metadata match { - case x : KafkaInputDStreamMetadata => - savedOffsets(x.timestamp) = x.data - // TOOD: Remove logging - logInfo("New saved Offsets: " + savedOffsets) - case _ => logInfo("Received unknown metadata: " + metadata.toString) - } - } - - override protected[streaming] def updateCheckpointData(currentTime: Time) { - super.updateCheckpointData(currentTime) - if(savedOffsets.size > 0) { - // Find the offets that were stored before the checkpoint was initiated - val key = savedOffsets.keys.toList.sortWith(_ < _).filter(_ < currentTime.millis).last - val latestOffsets = savedOffsets(key) - logInfo("Updating KafkaDStream checkpoint data: " + latestOffsets.toString) - checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, latestOffsets) - // TODO: This may throw out offsets that are created after the checkpoint, - // but it's unlikely we'll need them. 
- savedOffsets.clear() - } - } - - override protected[streaming] def restoreCheckpointData() { - super.restoreCheckpointData() - logInfo("Restoring KafkaDStream checkpoint data.") - checkpointData match { - case x : KafkaDStreamCheckpointData => - restoredOffsets = x.savedOffsets - logInfo("Restored KafkaDStream offsets: " + savedOffsets) - } - } */ - - def createReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel) - .asInstanceOf[NetworkReceiver[T]] - } -} - -class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, - topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], - storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) { - - // Timeout for establishing a connection to Zookeper in ms. - val ZK_TIMEOUT = 10000 - - // Handles pushing data into the BlockManager - lazy protected val dataHandler = new DataHandler(this, storageLevel) - // Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset - lazy val offsets = HashMap[KafkaPartitionKey, Long]() - // Connection to Kafka - var consumerConnector : ZookeeperConsumerConnector = null - - def onStop() { - dataHandler.stop() - } - - def onStart() { - - // Starting the DataHandler that buffers blocks and pushes them into them BlockManager - dataHandler.start() - - // In case we are using multiple Threads to handle Kafka Messages - val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) - - val zooKeeperEndPoint = host + ":" + port - logInfo("Starting Kafka Consumer Stream with group: " + groupId) - logInfo("Initial offsets: " + initialOffsets.toString) - - // Zookeper connection properties - val props = new Properties() - props.put("zk.connect", zooKeeperEndPoint) - props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString) - props.put("groupid", groupId) - - // Create the connection to the cluster - logInfo("Connecting to Zookeper: " + zooKeeperEndPoint) - val consumerConfig = new ConsumerConfig(props) - consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] - logInfo("Connected to " + zooKeeperEndPoint) - - // Reset the Kafka offsets in case we are recovering from a failure - resetOffsets(initialOffsets) - - // Create Threads for each Topic/Message Stream we are listening - val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) - - // Start the messages handler for each partition - topicMessageStreams.values.foreach { streams => - streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } - } - - } - - // Overwrites the offets in Zookeper. - private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) { - offsets.foreach { case(key, offset) => - val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic) - val partitionName = key.brokerId + "-" + key.partId - updatePersistentPath(consumerConnector.zkClient, - topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString) - } - } - - // Handles Kafka Messages - private class MessageHandler(stream: KafkaStream[String]) extends Runnable { - def run() { - logInfo("Starting MessageHandler.") - stream.takeWhile { msgAndMetadata => - dataHandler += msgAndMetadata.message - - // Updating the offet. The key is (broker, topic, group, partition). 
- val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic, - groupId, msgAndMetadata.topicInfo.partition.partId) - val offset = msgAndMetadata.topicInfo.getConsumeOffset - offsets.put(key, offset) - // logInfo("Handled message: " + (key, offset).toString) - - // Keep on handling messages - true - } - } - } - - // NOT USED - Originally intended for fault-tolerance - // class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel) - // extends DataHandler[Any](receiver, storageLevel) { - - // override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = { - // // Creates a new Block with Kafka-specific Metadata - // new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap)) - // } - - // } - -} diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 0d82b2f1ea..920388bba9 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -42,7 +42,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val stateStreamCheckpointInterval = Seconds(1) // this ensure checkpointing occurs at least once - val firstNumBatches = (stateStreamCheckpointInterval.millis / batchDuration.millis) * 2 + val firstNumBatches = (stateStreamCheckpointInterval / batchDuration) * 2 val secondNumBatches = firstNumBatches // Setup the streams diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala index 5b414117fc..4aa428bf64 100644 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -133,7 +133,7 @@ class FailureSuite extends TestSuiteBase with BeforeAndAfter { // Get the output buffer val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] val output = outputStream.output - val waitTime = (batchDuration.millis * (numBatches.toDouble + 0.5)).toLong + val waitTime = (batchDuration.milliseconds * (numBatches.toDouble + 0.5)).toLong val startTime = System.currentTimeMillis() try { diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index ed9a659092..76b528bec3 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -1,5 +1,6 @@ package spark.streaming +import dstream.SparkFlumeEvent import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket} import java.io.{File, BufferedWriter, OutputStreamWriter} import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index a44f738957..28bdd53c3c 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -1,12 +1,16 @@ package spark.streaming +import spark.streaming.dstream.{InputDStream, ForEachDStream} +import spark.streaming.util.ManualClock + import spark.{RDD, Logging} -import util.ManualClock + import collection.mutable.ArrayBuffer -import org.scalatest.FunSuite import collection.mutable.SynchronizedBuffer + import java.io.{ObjectInputStream, IOException} +import org.scalatest.FunSuite /** * 
This is a input stream just for the testsuites. This is equivalent to a checkpointable, @@ -70,6 +74,10 @@ trait TestSuiteBase extends FunSuite with Logging { def actuallyWait = false + /** + * Set up required DStreams to test the DStream operation using the two sequences + * of input collections. + */ def setupStreams[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V] @@ -90,6 +98,10 @@ trait TestSuiteBase extends FunSuite with Logging { ssc } + /** + * Set up required DStreams to test the binary operation using the sequence + * of input collections. + */ def setupStreams[U: ClassManifest, V: ClassManifest, W: ClassManifest]( input1: Seq[Seq[U]], input2: Seq[Seq[V]], @@ -173,6 +185,11 @@ trait TestSuiteBase extends FunSuite with Logging { output } + /** + * Verify whether the output values after running a DStream operation + * is same as the expected output values, by comparing the output + * collections either as lists (order matters) or sets (order does not matter) + */ def verifyOutput[V: ClassManifest]( output: Seq[Seq[V]], expectedOutput: Seq[Seq[V]], @@ -199,6 +216,10 @@ trait TestSuiteBase extends FunSuite with Logging { logInfo("Output verified successfully") } + /** + * Test unary DStream operation with a list of inputs, with number of + * batches to run same as the number of expected output values + */ def testOperation[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V], @@ -208,6 +229,15 @@ trait TestSuiteBase extends FunSuite with Logging { testOperation[U, V](input, operation, expectedOutput, -1, useSet) } + /** + * Test unary DStream operation with a list of inputs + * @param input Sequence of input collections + * @param operation Binary DStream operation to be applied to the 2 inputs + * @param expectedOutput Sequence of expected output collections + * @param numBatches Number of batches to run the operation for + * @param useSet Compare the output values with the expected output values + * as sets (order matters) or as lists (order does not matter) + */ def testOperation[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V], @@ -221,6 +251,10 @@ trait TestSuiteBase extends FunSuite with Logging { verifyOutput[V](output, expectedOutput, useSet) } + /** + * Test binary DStream operation with two lists of inputs, with number of + * batches to run same as the number of expected output values + */ def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( input1: Seq[Seq[U]], input2: Seq[Seq[V]], @@ -231,6 +265,16 @@ trait TestSuiteBase extends FunSuite with Logging { testOperation[U, V, W](input1, input2, operation, expectedOutput, -1, useSet) } + /** + * Test binary DStream operation with two lists of inputs + * @param input1 First sequence of input collections + * @param input2 Second sequence of input collections + * @param operation Binary DStream operation to be applied to the 2 inputs + * @param expectedOutput Sequence of expected output collections + * @param numBatches Number of batches to run the operation for + * @param useSet Compare the output values with the expected output values + * as sets (order matters) or as lists (order does not matter) + */ def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( input1: Seq[Seq[U]], input2: Seq[Seq[V]], diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index 
3e20e16708..4bc5229465 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -209,7 +209,7 @@ class WindowOperationsSuite extends TestSuiteBase { val expectedOutput = bigGroupByOutput.map(_.map(x => (x._1, x._2.toSet))) val windowTime = Seconds(2) val slideTime = Seconds(1) - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.groupByKeyAndWindow(windowTime, slideTime) .map(x => (x._1, x._2.toSet)) @@ -223,7 +223,7 @@ class WindowOperationsSuite extends TestSuiteBase { val expectedOutput = Seq( Seq(1), Seq(2), Seq(3), Seq(3), Seq(1), Seq(0)) val windowTime = Seconds(2) val slideTime = Seconds(1) - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[Int]) => s.countByWindow(windowTime, slideTime) testOperation(input, operation, expectedOutput, numBatches, true) } @@ -233,7 +233,7 @@ class WindowOperationsSuite extends TestSuiteBase { val expectedOutput = Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 2)), Seq(("a", 1), ("b", 3))) val windowTime = Seconds(2) val slideTime = Seconds(1) - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.countByKeyAndWindow(windowTime, slideTime).map(x => (x._1, x._2.toInt)) } @@ -251,7 +251,7 @@ class WindowOperationsSuite extends TestSuiteBase { slideTime: Time = Seconds(1) ) { test("window - " + name) { - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[Int]) => s.window(windowTime, slideTime) testOperation(input, operation, expectedOutput, numBatches, true) } @@ -265,7 +265,7 @@ class WindowOperationsSuite extends TestSuiteBase { slideTime: Time = Seconds(1) ) { test("reduceByKeyAndWindow - " + name) { - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist() } @@ -281,7 +281,7 @@ class WindowOperationsSuite extends TestSuiteBase { slideTime: Time = Seconds(1) ) { test("reduceByKeyAndWindowInv - " + name) { - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime) .persist() -- cgit v1.2.3 From 3dc87dd923578f20f2c6945be7d8c03797e76237 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 1 Jan 2013 16:38:04 -0800 Subject: Fixed compilation bug in RDDSuite created during merge for mesos/master. 
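A minimal sketch of how the TestSuiteBase helpers documented in the preceding patch can be used; the suite name, the doubling operation, and the input values are made up for illustration, and only the four-argument testOperation overload whose signature appears in that diff is assumed.

    package spark.streaming

    // Hypothetical suite: runs two batches of integers through map(_ * 2) and checks the output.
    class DoublingOperationSuite extends TestSuiteBase {
      test("map doubles every element") {
        val input    = Seq(Seq(1, 2, 3), Seq(4, 5, 6))     // one Seq per input batch
        val expected = Seq(Seq(2, 4, 6), Seq(8, 10, 12))   // expected output per batch
        val operation = (s: DStream[Int]) => s.map(_ * 2)  // unary DStream operation under test
        // Runs as many batches as there are expected outputs and compares them as ordered lists
        testOperation(input, operation, expected, false)
      }
    }
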
--- core/src/test/scala/spark/RDDSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index eab09956bb..e5a59dc7d6 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -105,9 +105,9 @@ class RDDSuite extends FunSuite with BeforeAndAfter { sc = new SparkContext("local", "test") val onlySplit = new Split { override def index: Int = 0 } var shouldFail = true - val rdd = new RDD[Int](sc) { - override def splits: Array[Split] = Array(onlySplit) - override val dependencies = List[Dependency[_]]() + val rdd = new RDD[Int](sc, Nil) { + override def getSplits: Array[Split] = Array(onlySplit) + override val getDependencies = List[Dependency[_]]() override def compute(split: Split, context: TaskContext): Iterator[Int] = { if (shouldFail) { throw new Exception("injected failure") -- cgit v1.2.3 From ecf9c0890160c69f1b64b36fa8fdea2f6dd973eb Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 5 Jan 2013 20:54:08 -0500 Subject: Fix Accumulators in Java, and add a test for them --- core/src/main/scala/spark/Accumulators.scala | 18 ++++++++- core/src/main/scala/spark/SparkContext.scala | 7 ++-- .../scala/spark/api/java/JavaSparkContext.scala | 23 +++++++---- core/src/test/scala/spark/JavaAPISuite.java | 44 ++++++++++++++++++++++ 4 files changed, 79 insertions(+), 13 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index bacd0ace37..6280f25391 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -38,14 +38,28 @@ class Accumulable[R, T] ( */ def += (term: T) { value_ = param.addAccumulator(value_, term) } + /** + * Add more data to this accumulator / accumulable + * @param term the data to add + */ + def add(term: T) { value_ = param.addAccumulator(value_, term) } + /** * Merge two accumulable objects together - * + * * Normally, a user will not want to use this version, but will instead call `+=`. - * @param term the other Accumulable that will get merged with this + * @param term the other `R` that will get merged with this */ def ++= (term: R) { value_ = param.addInPlace(value_, term)} + /** + * Merge two accumulable objects together + * + * Normally, a user will not want to use this version, but will instead call `add`. + * @param term the other `R` that will get merged with this + */ + def merge(term: R) { value_ = param.addInPlace(value_, term)} + /** * Access the accumulator's current value; only allowed on master. */ diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 4fd81bc63b..bbf8272eb3 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -382,11 +382,12 @@ class SparkContext( new Accumulator(initialValue, param) /** - * Create an [[spark.Accumulable]] shared variable, with a `+=` method + * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`. + * Only the master can access the accumuable's `value`. 
* @tparam T accumulator type * @tparam R type that can be added to the accumulator */ - def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) = + def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = new Accumulable(initialValue, param) /** @@ -404,7 +405,7 @@ class SparkContext( * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. */ - def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal) + def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) /** * Add a file to be downloaded into the working directory of this Spark job on every node. diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index b7725313c4..bf9ad7a200 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -10,7 +10,7 @@ import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import spark.{Accumulator, AccumulatorParam, RDD, SparkContext} +import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext} import spark.SparkContext.IntAccumulatorParam import spark.SparkContext.DoubleAccumulatorParam import spark.broadcast.Broadcast @@ -265,25 +265,32 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork /** * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ - def intAccumulator(initialValue: Int): Accumulator[Int] = - sc.accumulator(initialValue)(IntAccumulatorParam) + def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] = + sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]] /** * Create an [[spark.Accumulator]] double variable, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ - def doubleAccumulator(initialValue: Double): Accumulator[Double] = - sc.accumulator(initialValue)(DoubleAccumulatorParam) + def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = + sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] /** * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = sc.accumulator(initialValue)(accumulatorParam) + /** + * Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can + * "add" values with `add`. Only the master can access the accumuable's `value`. 
+ */ + def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = + sc.accumulable(initialValue)(param) + /** * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 33d5fc2d89..b99e790093 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -581,4 +581,48 @@ public class JavaAPISuite implements Serializable { JavaPairRDD zipped = rdd.zip(doubles); zipped.count(); } + + @Test + public void accumulators() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + + final Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + intAccum.add(x); + } + }); + Assert.assertEquals((Integer) 25, intAccum.value()); + + final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + doubleAccum.add((double) x); + } + }); + Assert.assertEquals((Double) 25.0, doubleAccum.value()); + + // Try a custom accumulator type + AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + public Float addInPlace(Float r, Float t) { + return r + t; + } + + public Float addAccumulator(Float r, Float t) { + return r + t; + } + + public Float zero(Float initialValue) { + return 0.0f; + } + }; + + final Accumulator floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + floatAccum.add((float) x); + } + }); + Assert.assertEquals((Float) 25.0f, floatAccum.value()); + } } -- cgit v1.2.3 From 1346126485444afc065bf4951c4bedebe5c95ce4 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Jan 2013 12:11:27 -0800 Subject: Changed cleanup to clearOldValues for TimeStampedHashMap and TimeStampedHashSet. 
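A small sketch of the clearOldValues semantics this commit introduces: entries whose insertion timestamp is older than the given threshold are dropped, while newer entries survive. It treats the map as the drop-in mutable HashMap replacement described in the TimeStampedHashMap comment below; the keys and values are illustrative.

    import spark.util.TimeStampedHashMap

    object ClearOldValuesSketch extends App {
      val cache = new TimeStampedHashMap[Int, String]
      cache(1) = "stale"                        // inserted before the threshold
      Thread.sleep(10)
      val threshold = System.currentTimeMillis  // entries older than this get removed
      cache(2) = "fresh"                        // inserted after the threshold
      cache.clearOldValues(threshold)           // drops key 1, keeps key 2
      assert(cache.get(1) == None)
      assert(cache.get(2) == Some("fresh"))
    }
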
--- core/src/main/scala/spark/CacheTracker.scala | 4 ++-- core/src/main/scala/spark/MapOutputTracker.scala | 4 ++-- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 6 +++--- core/src/main/scala/spark/scheduler/ResultTask.scala | 2 +- core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 2 +- core/src/main/scala/spark/util/TimeStampedHashMap.scala | 7 +++++-- core/src/main/scala/spark/util/TimeStampedHashSet.scala | 5 ++++- 7 files changed, 18 insertions(+), 12 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 7d320c4fe5..86ad737583 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -39,7 +39,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { private val slaveCapacity = new HashMap[String, Long] private val slaveUsage = new HashMap[String, Long] - private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.cleanup) + private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.clearOldValues) private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) @@ -120,7 +120,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b // Remembers which splits are currently being loaded (on worker nodes) val loading = new HashSet[String] - val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.cleanup) + val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.clearOldValues) // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. 
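The CacheTracker hunks above pass a TimeStampedHashMap's clearOldValues method directly to a MetadataCleaner. The sketch below shows that wiring pattern; it assumes MetadataCleaner lives in spark.util and takes a name plus a Long => Unit cleanup function, which is all this diff reveals about it, and the registry class itself is hypothetical.

    import spark.util.{MetadataCleaner, TimeStampedHashMap}

    // Hypothetical registry whose old entries are purged periodically by a MetadataCleaner,
    // instead of exposing an ad-hoc cleanup() method of its own.
    class TrackedRegistry {
      private val entries = new TimeStampedHashMap[String, Long]
      // The cleaner calls entries.clearOldValues(threshTime) on its own schedule
      private val metadataCleaner = new MetadataCleaner("TrackedRegistry", entries.clearOldValues)

      def register(name: String) { entries(name) = System.currentTimeMillis }
      def contains(name: String): Boolean = entries.contains(name)
    }
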
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 5ebdba0fc8..a2fa2d1ea7 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -178,8 +178,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea } def cleanup(cleanupTime: Long) { - mapStatuses.cleanup(cleanupTime) - cachedSerializedStatuses.cleanup(cleanupTime) + mapStatuses.clearOldValues(cleanupTime) + cachedSerializedStatuses.clearOldValues(cleanupTime) } def stop() { diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 9387ba19a3..59f2099e91 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -599,15 +599,15 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def cleanup(cleanupTime: Long) { var sizeBefore = idToStage.size - idToStage.cleanup(cleanupTime) + idToStage.clearOldValues(cleanupTime) logInfo("idToStage " + sizeBefore + " --> " + idToStage.size) sizeBefore = shuffleToMapStage.size - shuffleToMapStage.cleanup(cleanupTime) + shuffleToMapStage.clearOldValues(cleanupTime) logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size) sizeBefore = pendingTasks.size - pendingTasks.cleanup(cleanupTime) + pendingTasks.clearOldValues(cleanupTime) logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size) } diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index 7ec6564105..74a63c1af1 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -12,7 +12,7 @@ private[spark] object ResultTask { // expensive on the master node if it needs to launch thousands of tasks. val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.cleanup) + val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues) def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { synchronized { diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index feb63abb61..19f5328eee 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -23,7 +23,7 @@ private[spark] object ShuffleMapTask { // expensive on the master node if it needs to launch thousands of tasks. 
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.cleanup) + val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues) def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index 7e785182ea..bb7c5c01c8 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -7,7 +7,7 @@ import scala.collection.mutable.Map /** * This is a custom implementation of scala.collection.mutable.Map which stores the insertion * time stamp along with each key-value pair. Key-value pairs that are older than a particular - * threshold time can them be removed using the cleanup method. This is intended to be a drop-in + * threshold time can them be removed using the clearOldValues method. This is intended to be a drop-in * replacement of scala.collection.mutable.HashMap. */ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { @@ -74,7 +74,10 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { } } - def cleanup(threshTime: Long) { + /** + * Removes old key-value pairs that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { val iterator = internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() diff --git a/core/src/main/scala/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/spark/util/TimeStampedHashSet.scala index 539dd75844..5f1cc93752 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashSet.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashSet.scala @@ -52,7 +52,10 @@ class TimeStampedHashSet[A] extends Set[A] { } } - def cleanup(threshTime: Long) { + /** + * Removes old values that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { val iterator = internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() -- cgit v1.2.3 From 237bac36e9dca8828192994dad323b8da1619267 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Jan 2013 14:37:21 -0800 Subject: Renamed examples and added documentation. 
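This commit replaces WordCountHdfs with HdfsWordCount, whose contents are not shown here. The following is only a sketch of what such an example looks like, pieced together from the textFileStream, reduceByKey, and print calls that appear in the deleted FileStream example and in the programming guide below; the object name, usage message, and batch interval are illustrative.

    package spark.streaming.examples

    import spark.streaming.{Seconds, StreamingContext}
    import spark.streaming.StreamingContext._

    object HdfsWordCountSketch {
      def main(args: Array[String]) {
        if (args.length < 2) {
          System.err.println("Usage: HdfsWordCountSketch <master> <directory>")
          System.exit(1)
        }
        // Monitor a directory for newly created files and count the words in each batch
        val ssc = new StreamingContext(args(0), "HdfsWordCountSketch", Seconds(2))
        val lines = ssc.textFileStream(args(1))
        val wordCounts = lines.flatMap(_.split(" ")).map(x => (x, 1)).reduceByKey(_ + _)
        wordCounts.print()
        ssc.start()
      }
    }
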
--- core/src/main/scala/spark/RDDCheckpointData.scala | 4 +- docs/streaming-programming-guide.md | 14 ++-- .../spark/streaming/examples/FileStream.scala | 46 ------------- .../examples/FileStreamWithCheckpoint.scala | 75 ---------------------- .../spark/streaming/examples/FlumeEventCount.scala | 2 +- .../scala/spark/streaming/examples/GrepRaw.scala | 32 --------- .../spark/streaming/examples/HdfsWordCount.scala | 36 +++++++++++ .../streaming/examples/NetworkWordCount.scala | 36 +++++++++++ .../spark/streaming/examples/RawNetworkGrep.scala | 46 +++++++++++++ .../streaming/examples/TopKWordCountRaw.scala | 49 -------------- .../spark/streaming/examples/WordCountHdfs.scala | 25 -------- .../streaming/examples/WordCountNetwork.scala | 25 -------- .../spark/streaming/examples/WordCountRaw.scala | 43 ------------- .../scala/spark/streaming/StreamingContext.scala | 38 ++++++++--- .../spark/streaming/dstream/FileInputDStream.scala | 16 ++--- .../scala/spark/streaming/InputStreamsSuite.scala | 2 +- 16 files changed, 163 insertions(+), 326 deletions(-) delete mode 100644 examples/src/main/scala/spark/streaming/examples/FileStream.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/GrepRaw.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala (limited to 'core') diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index 7af830940f..e270b6312e 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -65,7 +65,7 @@ extends Logging with Serializable { cpRDD = Some(newRDD) rdd.changeDependencies(newRDD) cpState = Checkpointed - RDDCheckpointData.checkpointCompleted() + RDDCheckpointData.clearTaskCaches() logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) } } @@ -90,7 +90,7 @@ extends Logging with Serializable { } private[spark] object RDDCheckpointData { - def checkpointCompleted() { + def clearTaskCaches() { ShuffleMapTask.clearCache() ResultTask.clearCache() } diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index fc2ea2ef79..05a88ce7bd 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -187,8 +187,8 @@ Conversely, the computation can be stopped by using ssc.stop() {% endhighlight %} -# Example - WordCountNetwork.scala -A good example to start off is the spark.streaming.examples.WordCountNetwork. This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in /streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala. +# Example - NetworkWordCount.scala +A good example to start off is the spark.streaming.examples.NetworkWordCount. 
This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in /streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala. {% highlight scala %} import spark.streaming.{Seconds, StreamingContext} @@ -196,7 +196,7 @@ import spark.streaming.StreamingContext._ ... // Create the context and set up a network input stream to receive from a host:port -val ssc = new StreamingContext(args(0), "WordCountNetwork", Seconds(1)) +val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1)) val lines = ssc.networkTextStream(args(1), args(2).toInt) // Split the lines into words, count them, and print some of the counts on the master @@ -214,13 +214,13 @@ To run this example on your local machine, you need to first run a Netcat server $ nc -lk 9999 {% endhighlight %} -Then, in a different terminal, you can start WordCountNetwork by using +Then, in a different terminal, you can start NetworkWordCount by using {% highlight bash %} -$ ./run spark.streaming.examples.WordCountNetwork local[2] localhost 9999 +$ ./run spark.streaming.examples.NetworkWordCount local[2] localhost 9999 {% endhighlight %} -This will make WordCountNetwork connect to the netcat server. Any lines typed in the terminal running the netcat server will be counted and printed on screen. +This will make NetworkWordCount connect to the netcat server. Any lines typed in the terminal running the netcat server will be counted and printed on screen.
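The NetworkWordCount example above prints per-batch counts. A sketch of a windowed variant follows, using the reduceByKeyAndWindow operation with an inverse function that ReducedWindowedDStream, earlier in this series, evaluates incrementally; the host, port, window sizes, and checkpoint directory are illustrative, and the checkpoint call reflects the mustCheckpoint flag on that DStream rather than anything stated in the guide.

    import spark.streaming.{Seconds, StreamingContext}
    import spark.streaming.StreamingContext._

    object WindowedNetworkWordCountSketch {
      def main(args: Array[String]) {
        val ssc = new StreamingContext("local[2]", "WindowedNetworkWordCount", Seconds(1))
        // Windowed state is reused across batches, so give the context somewhere to checkpoint
        ssc.checkpoint("/tmp/streaming-checkpoint")
        val words = ssc.networkTextStream("localhost", 9999).flatMap(_.split(" "))
        // Counts over a sliding 30 second window: new batches are added with _ + _ and
        // batches that fall out of the window are removed with the inverse function _ - _
        val windowedCounts = words.map(x => (x, 1))
          .reduceByKeyAndWindow(_ + _, _ - _, Seconds(30), Seconds(1))
        windowedCounts.print()
        ssc.start()
      }
    }
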
[Scrambled diff hunk @@ -88,55 +88,60 @@ from docs/streaming-programming-guide.md, covering the HTML tables under "DStreams support many of the transformations available on normal Spark RDD's". Recoverable changes: the print() description becomes "Prints first ten elements of every batch of data in a DStream on the driver", and the output operators saveAsObjectFile(prefix, [suffix]), saveAsTextFile(prefix, suffix) and saveAsHadoopFiles(prefix, suffix) become saveAsObjectFiles(prefix, [suffix]), saveAsTextFiles(prefix, [suffix]) and saveAsHadoopFiles(prefix, [suffix]); each saves the DStream's contents with file names generated at each batch interval from prefix and suffix as "prefix-TIME_IN_MS[.suffix]".]
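To make the renamed operators concrete, here is a minimal sketch, assuming a `wordCounts: DStream[(String, Int)]` built as in the word count examples in this patch series; the prefix and suffix values are illustrative, not taken from the guide.

{% highlight scala %}
// Hedged sketch of the pluralized output operators from the table diff above.
// Each batch interval writes text files named "counts-<TIME_IN_MS>.txt", and the
// second call writes a SequenceFile of serialized objects per batch.
wordCounts.saveAsTextFiles("counts", "txt")
wordCounts.saveAsObjectFiles("counts-obj", "bin")
{% endhighlight %}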
@@ -240,7 +240,7 @@ hello world {% highlight bash %} -# TERMINAL 2: RUNNING WordCountNetwork +# TERMINAL 2: RUNNING NetworkWordCount ... 2012-12-31 18:47:10,446 INFO SparkContext: Job finished: run at ThreadPoolExecutor.java:886, took 0.038817 s ------------------------------------------- diff --git a/examples/src/main/scala/spark/streaming/examples/FileStream.scala b/examples/src/main/scala/spark/streaming/examples/FileStream.scala deleted file mode 100644 index 81938d30d4..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/FileStream.scala +++ /dev/null @@ -1,46 +0,0 @@ -package spark.streaming.examples - -import spark.streaming.StreamingContext -import spark.streaming.StreamingContext._ -import spark.streaming.Seconds -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration - - -object FileStream { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: FileStream ") - System.exit(1) - } - - // Create the context - val ssc = new StreamingContext(args(0), "FileStream", Seconds(1)) - - // Create the new directory - val directory = new Path(args(1)) - val fs = directory.getFileSystem(new Configuration()) - if (fs.exists(directory)) throw new Exception("This directory already exists") - fs.mkdirs(directory) - fs.deleteOnExit(directory) - - // Create the FileInputDStream on the directory and use the - // stream to count words in new files created - val inputStream = ssc.textFileStream(directory.toString) - val words = inputStream.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - - // Creating new files in the directory - val text = "This is a text file" - for (i <- 1 to 30) { - ssc.sc.parallelize((1 to (i * 10)).map(_ => text), 10) - .saveAsTextFile(new Path(directory, i.toString).toString) - Thread.sleep(1000) - } - Thread.sleep(5000) // Waiting for the file to be processed - ssc.stop() - System.exit(0) - } -} \ No newline at end of file diff --git a/examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala deleted file mode 100644 index b7bc15a1d5..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala +++ /dev/null @@ -1,75 +0,0 @@ -package spark.streaming.examples - -import spark.streaming._ -import spark.streaming.StreamingContext._ -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration - -object FileStreamWithCheckpoint { - - def main(args: Array[String]) { - - if (args.size != 3) { - println("FileStreamWithCheckpoint ") - println("FileStreamWithCheckpoint restart ") - System.exit(-1) - } - - val directory = new Path(args(1)) - val checkpointDir = args(2) - - val ssc: StreamingContext = { - - if (args(0) == "restart") { - - // Recreated streaming context from specified checkpoint file - new StreamingContext(checkpointDir) - - } else { - - // Create directory if it does not exist - val fs = directory.getFileSystem(new Configuration()) - if (!fs.exists(directory)) fs.mkdirs(directory) - - // Create new streaming context - val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint", Seconds(1)) - ssc_.checkpoint(checkpointDir) - - // Setup the streaming computation - val inputStream = ssc_.textFileStream(directory.toString) - val words = inputStream.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - - ssc_ - } - } - - // 
Start the stream computation - startFileWritingThread(directory.toString) - ssc.start() - } - - def startFileWritingThread(directory: String) { - - val fs = new Path(directory).getFileSystem(new Configuration()) - - val fileWritingThread = new Thread() { - override def run() { - val r = new scala.util.Random() - val text = "This is a sample text file with a random number " - while(true) { - val number = r.nextInt() - val file = new Path(directory, number.toString) - val fos = fs.create(file) - fos.writeChars(text + number) - fos.close() - println("Created text file " + file) - Thread.sleep(1000) - } - } - } - fileWritingThread.start() - } - -} diff --git a/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala b/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala index e60ce483a3..461929fba2 100644 --- a/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala +++ b/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala @@ -5,7 +5,7 @@ import spark.storage.StorageLevel import spark.streaming._ /** - * Produce a streaming count of events received from Flume. + * Produces a count of events received from Flume. * * This should be used in conjunction with an AvroSink in Flume. It will start * an Avro server on at the request host:port address and listen for requests. diff --git a/examples/src/main/scala/spark/streaming/examples/GrepRaw.scala b/examples/src/main/scala/spark/streaming/examples/GrepRaw.scala deleted file mode 100644 index 812faa368a..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/GrepRaw.scala +++ /dev/null @@ -1,32 +0,0 @@ -package spark.streaming.examples - -import spark.util.IntParam -import spark.storage.StorageLevel - -import spark.streaming._ -import spark.streaming.util.RawTextHelper._ - -object GrepRaw { - def main(args: Array[String]) { - if (args.length != 5) { - System.err.println("Usage: GrepRaw ") - System.exit(1) - } - - val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args - - // Create the context - val ssc = new StreamingContext(master, "GrepRaw", Milliseconds(batchMillis)) - - // Warm up the JVMs on master and slave for JIT compilation to kick in - warmUp(ssc.sc) - - - val rawStreams = (1 to numStreams).map(_ => - ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray - val union = ssc.union(rawStreams) - union.filter(_.contains("Alice")).count().foreach(r => - println("Grep count: " + r.collect().mkString)) - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala b/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala new file mode 100644 index 0000000000..8530f5c175 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala @@ -0,0 +1,36 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ + + +/** + * Counts words in new text files created in the given directory + * Usage: HdfsWordCount + * is the Spark master URL. + * is the directory that Spark Streaming will use to find and read new text files. + * + * To run this on your local machine on directory `localdir`, run this example + * `$ ./run spark.streaming.examples.HdfsWordCount local[2] localdir` + * Then create a text file in `localdir` and the words in the file will get counted. 
+ */ +object HdfsWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: HdfsWordCount ") + System.exit(1) + } + + // Create the context + val ssc = new StreamingContext(args(0), "HdfsWordCount", Seconds(2)) + + // Create the FileInputDStream on the directory and use the + // stream to count words in new files created + val lines = ssc.textFileStream(args(1)) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} + diff --git a/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala new file mode 100644 index 0000000000..43c01d5db2 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala @@ -0,0 +1,36 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + * Usage: NetworkWordCount + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ ./run spark.streaming.examples.NetworkWordCount local[2] localhost 9999` + */ +object NetworkWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: NetworkWordCount \n" + + "In local mode, should be 'local[n]' with n > 1") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1)) + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. generated by 'nc') + val lines = ssc.networkTextStream(args(1), args(2).toInt) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala new file mode 100644 index 0000000000..2eec777c54 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala @@ -0,0 +1,46 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel + +import spark.streaming._ +import spark.streaming.util.RawTextHelper + +/** + * Receives text from multiple rawNetworkStreams and counts how many '\n' delimited + * lines have the word 'the' in them. This is useful for benchmarking purposes. This + * will only work with spark.streaming.util.RawTextSender running on all worker nodes + * and with Spark using Kryo serialization (set Java property "spark.serializer" to + * "spark.KryoSerializer"). + * Usage: RawNetworkGrep + * is the Spark master URL + * is the number rawNetworkStreams, which should be same as number + * of work nodes in the cluster + * is "localhost". + * is the port on which RawTextSender is running in the worker nodes. + * is the Spark Streaming batch duration in milliseconds. 
+ */ + +object RawNetworkGrep { + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println("Usage: RawNetworkGrep ") + System.exit(1) + } + + val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args + + // Create the context + val ssc = new StreamingContext(master, "RawNetworkGrep", Milliseconds(batchMillis)) + + // Warm up the JVMs on master and slave for JIT compilation to kick in + RawTextHelper.warmUp(ssc.sc) + + val rawStreams = (1 to numStreams).map(_ => + ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray + val union = ssc.union(rawStreams) + union.filter(_.contains("the")).count().foreach(r => + println("Grep count: " + r.collect().mkString)) + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala deleted file mode 100644 index 338834bc3c..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ /dev/null @@ -1,49 +0,0 @@ -package spark.streaming.examples - -import spark.storage.StorageLevel -import spark.util.IntParam - -import spark.streaming._ -import spark.streaming.StreamingContext._ -import spark.streaming.util.RawTextHelper._ - -import java.util.UUID - -object TopKWordCountRaw { - - def main(args: Array[String]) { - if (args.length != 4) { - System.err.println("Usage: WordCountRaw <# streams> ") - System.exit(1) - } - - val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args - val k = 10 - - // Create the context, and set the checkpoint directory. - // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts - // periodically to HDFS - val ssc = new StreamingContext(master, "TopKWordCountRaw", Seconds(1)) - ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) - - // Warm up the JVMs on master and slave for JIT compilation to kick in - /*warmUp(ssc.sc)*/ - - // Set up the raw network streams that will connect to localhost:port to raw test - // senders on the slaves and generate top K words of last 30 seconds - val lines = (1 to numStreams).map(_ => { - ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) - }) - val union = ssc.union(lines) - val counts = union.mapPartitions(splitAndCountPartitions) - val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreach(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " words from partial top words") - println("Top " + k + " words are " + topK(collectedCounts.toIterator, k).mkString(",")) - }) - - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala b/examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala deleted file mode 100644 index 867a8f42c4..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala +++ /dev/null @@ -1,25 +0,0 @@ -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ - -object WordCountHdfs { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: WordCountHdfs ") - System.exit(1) - } - - // Create the context - val ssc = new StreamingContext(args(0), "WordCountHdfs", Seconds(2)) - - // Create 
the FileInputDStream on the directory and use the - // stream to count words in new files created - val lines = ssc.textFileStream(args(1)) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - } -} - diff --git a/examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala b/examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala deleted file mode 100644 index eadda60563..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala +++ /dev/null @@ -1,25 +0,0 @@ -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ - -object WordCountNetwork { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: WordCountNetwork \n" + - "In local mode, should be 'local[n]' with n > 1") - System.exit(1) - } - - // Create the context and set the batch size - val ssc = new StreamingContext(args(0), "WordCountNetwork", Seconds(1)) - - // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') - val lines = ssc.networkTextStream(args(1), args(2).toInt) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala deleted file mode 100644 index d93335a8ce..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ /dev/null @@ -1,43 +0,0 @@ -package spark.streaming.examples - -import spark.storage.StorageLevel -import spark.util.IntParam - -import spark.streaming._ -import spark.streaming.StreamingContext._ -import spark.streaming.util.RawTextHelper._ - -import java.util.UUID - -object WordCountRaw { - - def main(args: Array[String]) { - if (args.length != 4) { - System.err.println("Usage: WordCountRaw <# streams> ") - System.exit(1) - } - - val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args - - // Create the context, and set the checkpoint directory. 
- // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts - // periodically to HDFS - val ssc = new StreamingContext(master, "WordCountRaw", Seconds(1)) - ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) - - // Warm up the JVMs on master and slave for JIT compilation to kick in - warmUp(ssc.sc) - - // Set up the raw network streams that will connect to localhost:port to raw test - // senders on the slaves and generate count of words of last 30 seconds - val lines = (1 to numStreams).map(_ => { - ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) - }) - val union = ssc.union(lines) - val counts = union.mapPartitions(splitAndCountPartitions) - val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) - windowedCounts.foreach(r => println("# unique words = " + r.count())) - - ssc.start() - } -} diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 7256e41af9..215246ba2e 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -154,7 +154,7 @@ class StreamingContext private ( storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[T] = { val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, topics, initialOffsets, storageLevel) - graph.addInputStream(inputStream) + registerInputStream(inputStream) inputStream } @@ -192,7 +192,7 @@ class StreamingContext private ( storageLevel: StorageLevel ): DStream[T] = { val inputStream = new SocketInputDStream[T](this, hostname, port, converter, storageLevel) - graph.addInputStream(inputStream) + registerInputStream(inputStream) inputStream } @@ -208,7 +208,7 @@ class StreamingContext private ( storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 ): DStream[SparkFlumeEvent] = { val inputStream = new FlumeInputDStream(this, hostname, port, storageLevel) - graph.addInputStream(inputStream) + registerInputStream(inputStream) inputStream } @@ -228,13 +228,14 @@ class StreamingContext private ( storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 ): DStream[T] = { val inputStream = new RawInputDStream[T](this, hostname, port, storageLevel) - graph.addInputStream(inputStream) + registerInputStream(inputStream) inputStream } /** * Creates a input stream that monitors a Hadoop-compatible filesystem * for new files and reads them using the given key-value types and input format. + * File names starting with . are ignored. * @param directory HDFS directory to monitor for new file * @tparam K Key type for reading HDFS file * @tparam V Value type for reading HDFS file @@ -244,16 +245,37 @@ class StreamingContext private ( K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K, V]: ClassManifest - ](directory: String): DStream[(K, V)] = { + ] (directory: String): DStream[(K, V)] = { val inputStream = new FileInputDStream[K, V, F](this, directory) - graph.addInputStream(inputStream) + registerInputStream(inputStream) + inputStream + } + + /** + * Creates a input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them using the given key-value types and input format. 
+ * @param directory HDFS directory to monitor for new file + * @param filter Function to filter paths to process + * @param newFilesOnly Should process only new files and ignore existing files in the directory + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file + */ + def fileStream[ + K: ClassManifest, + V: ClassManifest, + F <: NewInputFormat[K, V]: ClassManifest + ] (directory: String, filter: Path => Boolean, newFilesOnly: Boolean): DStream[(K, V)] = { + val inputStream = new FileInputDStream[K, V, F](this, directory, filter, newFilesOnly) + registerInputStream(inputStream) inputStream } + /** * Creates a input stream that monitors a Hadoop-compatible filesystem * for new files and reads them as text files (using key as LongWritable, value - * as Text and input format as TextInputFormat). + * as Text and input format as TextInputFormat). File names starting with . are ignored. * @param directory HDFS directory to monitor for new file */ def textFileStream(directory: String): DStream[String] = { @@ -274,7 +296,7 @@ class StreamingContext private ( defaultRDD: RDD[T] = null ): DStream[T] = { val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD) - graph.addInputStream(inputStream) + registerInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala index cf72095324..1e6ad84b44 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala @@ -14,7 +14,7 @@ private[streaming] class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( @transient ssc_ : StreamingContext, directory: String, - filter: PathFilter = FileInputDStream.defaultPathFilter, + filter: Path => Boolean = FileInputDStream.defaultFilter, newFilesOnly: Boolean = true) extends InputDStream[(K, V)](ssc_) { @@ -60,7 +60,7 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K val latestModTimeFiles = new HashSet[String]() def accept(path: Path): Boolean = { - if (!filter.accept(path)) { + if (!filter(path)) { return false } else { val modTime = fs.getFileStatus(path).getModificationTime() @@ -95,16 +95,8 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K } } +private[streaming] object FileInputDStream { - val defaultPathFilter = new PathFilter with Serializable { - def accept(path: Path): Boolean = { - val file = path.getName() - if (file.startsWith(".") || file.endsWith("_tmp")) { - return false - } else { - return true - } - } - } + def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".") } diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 76b528bec3..00ee903c1e 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -318,7 +318,7 @@ class TestServer(port: Int) extends Logging { } } } catch { - case e: SocketException => println(e) + case e: SocketException => logInfo(e) } finally { logInfo("Connection closed") if (!clientSocket.isClosed) clientSocket.close() -- cgit v1.2.3 From 3b0a3b89ac508b57b8afbd1ca7024ee558a5d1af Mon Sep 17 00:00:00 2001 From: Tathagata Das 
Date: Mon, 7 Jan 2013 14:55:49 -0800 Subject: Added better docs for RDDCheckpointData --- core/src/main/scala/spark/RDDCheckpointData.scala | 10 +++++++++- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 4 ++++ 2 files changed, 13 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index e270b6312e..d845a522e4 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -14,15 +14,23 @@ private[spark] object CheckpointState extends Enumeration { } /** - * This class contains all the information of the regarding RDD checkpointing. + * This class contains all the information related to RDD checkpointing. Each instance of this class + * is associated with a RDD. It manages process of checkpointing of the associated RDD, as well as, + * manages the post-checkpoint state by providing the updated splits, iterator and preferred locations + * of the checkpointed RDD. */ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) extends Logging with Serializable { import CheckpointState._ + // The checkpoint state of the associated RDD. var cpState = Initialized + + // The file to which the associated RDD has been checkpointed to @transient var cpFile: Option[String] = None + + // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. @transient var cpRDD: Option[RDD[T]] = None // Mark the RDD for checkpointing diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 1a88d402c3..86c63ca2f4 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -13,6 +13,10 @@ private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends override val index: Int = idx } +/** + * This RDD represents a RDD checkpoint file (similar to HadoopRDD). + */ +private[spark] class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) extends RDD[T](sc, Nil) { -- cgit v1.2.3 From 4719e6d8fe6d93734f5bbe6c91dcc4616c1ed317 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Jan 2013 16:06:07 -0800 Subject: Changed locations for unit test logs. 
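The RDDCheckpointData and CheckpointRDD documentation added above describes a state machine driven from user code. As a rough illustration, a minimal sketch of the user-facing flow, assuming the Spark API of this era; the application name and checkpoint path are illustrative.

{% highlight scala %}
import spark.SparkContext

// Hedged sketch: illustrative app name and checkpoint directory.
val sc = new SparkContext("local", "CheckpointSketch")
sc.setCheckpointDir("/tmp/spark-checkpoints")   // would be an HDFS path on a cluster

val data = sc.parallelize(1 to 1000).map(_ * 2)
data.checkpoint()   // marks the RDD; its RDDCheckpointData starts in the Initialized state
data.count()        // the first job materializes the RDD and writes the checkpoint file;
                    // the state ends up Checkpointed, with a CheckpointRDD as the new parent
{% endhighlight %}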
--- bagel/src/test/resources/log4j.properties | 4 ++-- core/src/test/resources/log4j.properties | 4 ++-- repl/src/test/resources/log4j.properties | 4 ++-- streaming/src/test/resources/log4j.properties | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) (limited to 'core') diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties index 4c99e450bc..83d05cab2f 100644 --- a/bagel/src/test/resources/log4j.properties +++ b/bagel/src/test/resources/log4j.properties @@ -1,8 +1,8 @@ -# Set everything to be logged to the console +# Set everything to be logged to the file bagel/target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false -log4j.appender.file.file=spark-tests.log +log4j.appender.file.file=bagel/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index 5ed388e91b..6ec89c0184 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -1,8 +1,8 @@ -# Set everything to be logged to the file spark-tests.log +# Set everything to be logged to the file core/target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false -log4j.appender.file.file=spark-tests.log +log4j.appender.file.file=core/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties index 4c99e450bc..cfb1a390e6 100644 --- a/repl/src/test/resources/log4j.properties +++ b/repl/src/test/resources/log4j.properties @@ -1,8 +1,8 @@ -# Set everything to be logged to the console +# Set everything to be logged to the repl/target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false -log4j.appender.file.file=spark-tests.log +log4j.appender.file.file=repl/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 33bafebaab..edfa1243fa 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -1,8 +1,8 @@ -# Set everything to be logged to the file streaming-tests.log +# Set everything to be logged to the file streaming/target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false -log4j.appender.file.file=streaming-tests.log +log4j.appender.file.file=streaming/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n -- cgit v1.2.3 From e3861ae3953d7cab66160833688c8baf84e835ad Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Wed, 9 Jan 2013 17:03:25 -0600 Subject: Provide and expose a default Hadoop Configuration. Any "hadoop.*" system properties will be passed along into configuration. 
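As a sketch of the behaviour this commit message describes (the property name is illustrative, and note that a later commit in this series narrows the prefix to "spark.hadoop.*"):

{% highlight scala %}
// Hedged sketch: a "hadoop.foo=bar" system property set before the SparkContext is
// created is copied into sc.hadoopConfiguration as "foo=bar".
System.setProperty("hadoop.dfs.replication", "2")   // illustrative property

val sc = new spark.SparkContext("local", "HadoopConfSketch")
println(sc.hadoopConfiguration.get("dfs.replication"))   // expected to print "2"
{% endhighlight %}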
--- core/src/main/scala/spark/SparkContext.scala | 18 ++++++++++++++---- .../main/scala/spark/api/java/JavaSparkContext.scala | 7 +++++++ 2 files changed, 21 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bbf8272eb3..36e0938854 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -187,6 +187,18 @@ class SparkContext( private var dagScheduler = new DAGScheduler(taskScheduler) + /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ + val hadoopConfiguration = { + val conf = new Configuration() + // Copy any "hadoop.foo=bar" system properties into conf as "foo=bar" + for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("hadoop.")) { + conf.set(key.substring("hadoop.".length), System.getProperty(key)) + } + val bufferSize = System.getProperty("spark.buffer.size", "65536") + conf.set("io.file.buffer.size", bufferSize) + conf + } + // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. */ @@ -231,10 +243,8 @@ class SparkContext( valueClass: Class[V], minSplits: Int = defaultMinSplits ) : RDD[(K, V)] = { - val conf = new JobConf() + val conf = new JobConf(hadoopConfiguration) FileInputFormat.setInputPaths(conf, path) - val bufferSize = System.getProperty("spark.buffer.size", "65536") - conf.set("io.file.buffer.size", bufferSize) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) } @@ -276,7 +286,7 @@ class SparkContext( fm.erasure.asInstanceOf[Class[F]], km.erasure.asInstanceOf[Class[K]], vm.erasure.asInstanceOf[Class[V]], - new Configuration) + hadoopConfiguration) } /** diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index 88ab2846be..12e2a0bdac 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -355,6 +355,13 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def clearFiles() { sc.clearFiles() } + + /** + * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. + */ + def hadoopConfiguration() { + sc.hadoopConfiguration + } } object JavaSparkContext { -- cgit v1.2.3 From b15e8512793475eaeda7225a259db8aacd600741 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Thu, 10 Jan 2013 10:55:41 -0600 Subject: Check for AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY environment variables. For custom properties, use "spark.hadoop.*" as a prefix instead of just "hadoop.*". --- core/src/main/scala/spark/SparkContext.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 36e0938854..7b11955f1e 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -190,9 +190,16 @@ class SparkContext( /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. 
*/ val hadoopConfiguration = { val conf = new Configuration() - // Copy any "hadoop.foo=bar" system properties into conf as "foo=bar" - for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("hadoop.")) { - conf.set(key.substring("hadoop.".length), System.getProperty(key)) + // Explicitly check for S3 environment variables + if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { + conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + } + // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" + for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("spark.hadoop.")) { + conf.set(key.substring("spark.hadoop.".length), System.getProperty(key)) } val bufferSize = System.getProperty("spark.buffer.size", "65536") conf.set("io.file.buffer.size", bufferSize) -- cgit v1.2.3 From d1864052c58ff1e58980729f7ccf00e630f815b9 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Thu, 10 Jan 2013 12:16:26 -0600 Subject: Fix invalid asInstanceOf cast. --- core/src/main/scala/spark/SparkContext.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 7b11955f1e..d2a5b4757a 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -7,6 +7,7 @@ import java.net.{URI, URLClassLoader} import scala.collection.Map import scala.collection.generic.Growable import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.JavaConversions._ import akka.actor.Actor import akka.actor.Actor._ @@ -198,7 +199,7 @@ class SparkContext( conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) } // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" - for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("spark.hadoop.")) { + for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) { conf.set(key.substring("spark.hadoop.".length), System.getProperty(key)) } val bufferSize = System.getProperty("spark.buffer.size", "65536") -- cgit v1.2.3 From 3e6519a36e354f3623c5b968efe5217c7fcb242f Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 11 Jan 2013 11:24:20 -0600 Subject: Use hadoopConfiguration for default JobConf in PairRDDFunctions. 
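In user terms, the change below means a save call that omits the JobConf now inherits the context's Hadoop settings (S3 credentials, "spark.hadoop.*" overrides). A hedged sketch, with an illustrative path and classes, assuming `sc: SparkContext` is already available:

{% highlight scala %}
import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapred.TextOutputFormat
import spark.SparkContext._   // pair RDD implicits

// Hedged sketch: no JobConf is passed, so the default new JobConf(sc.hadoopConfiguration)
// carries the context's Hadoop settings into the output job. The path is illustrative.
val counts = sc.parallelize(Seq(("a", 1), ("b", 2)))
  .map { case (k, v) => (new Text(k), new IntWritable(v)) }
counts.saveAsHadoopFile("/tmp/sketch-output",
  classOf[Text], classOf[IntWritable], classOf[TextOutputFormat[Text, IntWritable]])
{% endhighlight %}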
--- core/src/main/scala/spark/PairRDDFunctions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index ce48cea903..51c15837c4 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -557,7 +557,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf = new JobConf) { + conf: JobConf = new JobConf(self.context.hadoopConfiguration)) { conf.setOutputKeyClass(keyClass) conf.setOutputValueClass(valueClass) // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug -- cgit v1.2.3 From 5c7a1272198c88a90a843bbda0c1424f92b7c12e Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 11 Jan 2013 11:25:11 -0600 Subject: Pass a new Configuration that wraps the default hadoopConfiguration. --- core/src/main/scala/spark/SparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d2a5b4757a..f6b98c41bc 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -294,7 +294,7 @@ class SparkContext( fm.erasure.asInstanceOf[Class[F]], km.erasure.asInstanceOf[Class[K]], vm.erasure.asInstanceOf[Class[V]], - hadoopConfiguration) + new Configuration(hadoopConfiguration)) } /** -- cgit v1.2.3 From 22445fbea9ed1575e49a1f9bb2251d98a57b9e4e Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Fri, 11 Jan 2013 13:30:49 -0800 Subject: attempt to sleep for more accurate time period, minor cleanup --- .../scala/spark/util/RateLimitedOutputStream.scala | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index d11ed163ce..3050213709 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -1,8 +1,10 @@ package spark.util import java.io.OutputStream +import java.util.concurrent.TimeUnit._ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { + val SyncIntervalNs = NANOSECONDS.convert(10, SECONDS) var lastSyncTime = System.nanoTime() var bytesWrittenSinceSync: Long = 0 @@ -28,20 +30,21 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu def waitToWrite(numBytes: Int) { while (true) { - val now = System.nanoTime() - val elapsed = math.max(now - lastSyncTime, 1) - val rate = bytesWrittenSinceSync.toDouble / (elapsed / 1.0e9) + val now = System.nanoTime + val elapsedSecs = SECONDS.convert(max(now - lastSyncTime, 1), NANOSECONDS) + val rate = bytesWrittenSinceSync.toDouble / elapsedSecs if (rate < bytesPerSec) { // It's okay to write; just update some variables and return bytesWrittenSinceSync += numBytes - if (now > lastSyncTime + (1e10).toLong) { - // Ten seconds have passed since lastSyncTime; let's resync + if (now > lastSyncTime + SyncIntervalNs) { + // Sync interval has passed; let's resync lastSyncTime = now bytesWrittenSinceSync = numBytes } - return } else { - Thread.sleep(5) + // Calculate how much time we should sleep to bring 
ourselves to the desired rate. + val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) + if (sleepTime > 0) Thread.sleep(sleepTime) } } } @@ -53,4 +56,4 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu override def close() { out.close() } -} \ No newline at end of file +} -- cgit v1.2.3 From ff10b3aa0970cc7224adc6bc73d99a7ffa30219f Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Fri, 11 Jan 2013 21:03:57 -0800 Subject: add missing return --- core/src/main/scala/spark/util/RateLimitedOutputStream.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index 3050213709..ed459c2544 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -41,6 +41,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu lastSyncTime = now bytesWrittenSinceSync = numBytes } + return } else { // Calculate how much time we should sleep to bring ourselves to the desired rate. val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) -- cgit v1.2.3 From 0cfea7a2ec467717fbe110f9b15163bea2719575 Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Fri, 11 Jan 2013 23:48:07 -0800 Subject: add unit test --- .../scala/spark/util/RateLimitedOutputStream.scala | 2 +- .../spark/util/RateLimitedOutputStreamSuite.scala | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala (limited to 'core') diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index ed459c2544..16db7549b2 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -31,7 +31,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu def waitToWrite(numBytes: Int) { while (true) { val now = System.nanoTime - val elapsedSecs = SECONDS.convert(max(now - lastSyncTime, 1), NANOSECONDS) + val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) val rate = bytesWrittenSinceSync.toDouble / elapsedSecs if (rate < bytesPerSec) { // It's okay to write; just update some variables and return diff --git a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala new file mode 100644 index 0000000000..1dc45e0433 --- /dev/null +++ b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala @@ -0,0 +1,22 @@ +package spark.util + +import org.scalatest.FunSuite +import java.io.ByteArrayOutputStream +import java.util.concurrent.TimeUnit._ + +class RateLimitedOutputStreamSuite extends FunSuite { + + private def benchmark[U](f: => U): Long = { + val start = System.nanoTime + f + System.nanoTime - start + } + + test("write") { + val underlying = new ByteArrayOutputStream + val data = "X" * 1000 + val stream = new RateLimitedOutputStream(underlying, 100) + val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } + assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4) + } +} -- cgit v1.2.3 From 2c77eeebb66a3d1337d45b5001be2b48724f9fd5 Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Sat, 12 Jan 
2013 00:13:45 -0800 Subject: correct test params --- core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala index 1dc45e0433..b392075482 100644 --- a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala +++ b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala @@ -14,8 +14,8 @@ class RateLimitedOutputStreamSuite extends FunSuite { test("write") { val underlying = new ByteArrayOutputStream - val data = "X" * 1000 - val stream = new RateLimitedOutputStream(underlying, 100) + val data = "X" * 41000 + val stream = new RateLimitedOutputStream(underlying, 10000) val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4) } -- cgit v1.2.3 From ea20ae661888d871f70d5ed322cfe924c5a31dba Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Sat, 12 Jan 2013 09:18:00 -0800 Subject: add one extra test --- core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala index b392075482..794063fb6d 100644 --- a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala +++ b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala @@ -18,5 +18,6 @@ class RateLimitedOutputStreamSuite extends FunSuite { val stream = new RateLimitedOutputStream(underlying, 10000) val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4) + assert(underlying.toString("UTF-8") == data) } } -- cgit v1.2.3 From addff2c466d4b76043e612d4d28ab9de7f003298 Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Sat, 12 Jan 2013 09:57:29 -0800 Subject: add comment --- core/src/main/scala/spark/util/RateLimitedOutputStream.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index 16db7549b2..ed3d2b66bb 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -44,6 +44,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu return } else { // Calculate how much time we should sleep to bring ourselves to the desired rate. 
+ // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) if (sleepTime > 0) Thread.sleep(sleepTime) } -- cgit v1.2.3 From 2305a2c1d91273a93ee6b571b0cd4bcaa1b2969d Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Sun, 13 Jan 2013 10:01:56 -0800 Subject: more code cleanup --- .../scala/spark/util/RateLimitedOutputStream.scala | 63 +++++++++++----------- 1 file changed, 32 insertions(+), 31 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index ed3d2b66bb..10790a9eee 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -1,11 +1,14 @@ package spark.util +import scala.annotation.tailrec + import java.io.OutputStream import java.util.concurrent.TimeUnit._ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { val SyncIntervalNs = NANOSECONDS.convert(10, SECONDS) - var lastSyncTime = System.nanoTime() + val ChunkSize = 8192 + var lastSyncTime = System.nanoTime var bytesWrittenSinceSync: Long = 0 override def write(b: Int) { @@ -17,37 +20,13 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu write(bytes, 0, bytes.length) } - override def write(bytes: Array[Byte], offset: Int, length: Int) { - val CHUNK_SIZE = 8192 - var pos = 0 - while (pos < length) { - val writeSize = math.min(length - pos, CHUNK_SIZE) + @tailrec + override final def write(bytes: Array[Byte], offset: Int, length: Int) { + val writeSize = math.min(length - offset, ChunkSize) + if (writeSize > 0) { waitToWrite(writeSize) - out.write(bytes, offset + pos, writeSize) - pos += writeSize - } - } - - def waitToWrite(numBytes: Int) { - while (true) { - val now = System.nanoTime - val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) - val rate = bytesWrittenSinceSync.toDouble / elapsedSecs - if (rate < bytesPerSec) { - // It's okay to write; just update some variables and return - bytesWrittenSinceSync += numBytes - if (now > lastSyncTime + SyncIntervalNs) { - // Sync interval has passed; let's resync - lastSyncTime = now - bytesWrittenSinceSync = numBytes - } - return - } else { - // Calculate how much time we should sleep to bring ourselves to the desired rate. 
- // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) - val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) - if (sleepTime > 0) Thread.sleep(sleepTime) - } + out.write(bytes, offset, writeSize) + write(bytes, offset + writeSize, length) } } @@ -58,4 +37,26 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu override def close() { out.close() } + + @tailrec + private def waitToWrite(numBytes: Int) { + val now = System.nanoTime + val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) + val rate = bytesWrittenSinceSync.toDouble / elapsedSecs + if (rate < bytesPerSec) { + // It's okay to write; just update some variables and return + bytesWrittenSinceSync += numBytes + if (now > lastSyncTime + SyncIntervalNs) { + // Sync interval has passed; let's resync + lastSyncTime = now + bytesWrittenSinceSync = numBytes + } + } else { + // Calculate how much time we should sleep to bring ourselves to the desired rate. + // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) + val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) + if (sleepTime > 0) Thread.sleep(sleepTime) + waitToWrite(numBytes) + } + } } -- cgit v1.2.3 From c31931af7eb01fbe2bb276bb6f428248128832b0 Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Sun, 13 Jan 2013 10:39:47 -0800 Subject: switch to uppercase constants --- core/src/main/scala/spark/util/RateLimitedOutputStream.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index 10790a9eee..e3f00ea8c7 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -6,8 +6,8 @@ import java.io.OutputStream import java.util.concurrent.TimeUnit._ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { - val SyncIntervalNs = NANOSECONDS.convert(10, SECONDS) - val ChunkSize = 8192 + val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) + val CHUNK_SIZE = 8192 var lastSyncTime = System.nanoTime var bytesWrittenSinceSync: Long = 0 @@ -22,7 +22,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu @tailrec override final def write(bytes: Array[Byte], offset: Int, length: Int) { - val writeSize = math.min(length - offset, ChunkSize) + val writeSize = math.min(length - offset, CHUNK_SIZE) if (writeSize > 0) { waitToWrite(writeSize) out.write(bytes, offset, writeSize) @@ -46,7 +46,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu if (rate < bytesPerSec) { // It's okay to write; just update some variables and return bytesWrittenSinceSync += numBytes - if (now > lastSyncTime + SyncIntervalNs) { + if (now > lastSyncTime + SYNC_INTERVAL) { // Sync interval has passed; let's resync lastSyncTime = now bytesWrittenSinceSync = numBytes -- cgit v1.2.3 From 0dbd411a562396e024c513936fde46b0d2f6d59d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 13 Jan 2013 21:08:35 -0800 Subject: Added documentation for PairDStreamFunctions. 
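For reference, a hedged sketch of the pair-RDD join whose scaladoc is corrected in the diff below; it assumes `sc: SparkContext` and `import spark.SparkContext._`, with illustrative data:

{% highlight scala %}
// Matching keys from the two RDDs are combined into (k, (v1, v2)) tuples.
val left  = sc.parallelize(Seq(("a", 1), ("b", 2)))
val right = sc.parallelize(Seq(("a", "x"), ("c", "y")))
left.join(right).collect()   // Array(("a", (1, "x")))
{% endhighlight %}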
--- core/src/main/scala/spark/PairRDDFunctions.scala | 6 +- docs/streaming-programming-guide.md | 45 ++-- .../src/main/scala/spark/streaming/DStream.scala | 35 ++- .../spark/streaming/PairDStreamFunctions.scala | 293 ++++++++++++++++++++- .../scala/spark/streaming/util/RawTextHelper.scala | 2 +- 5 files changed, 331 insertions(+), 50 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 07ae2d647c..d95b66ad78 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -199,9 +199,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. */ def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { this.cogroup(other, partitioner).flatMapValues { diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 05a88ce7bd..b6da7af654 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -43,7 +43,7 @@ A complete list of input sources is available in the [StreamingContext API docum # DStream Operations -Once an input stream has been created, you can transform it using _stream operators_. Most of these operators return new DStreams which you can further transform. Eventually, you'll need to call an _output operator_, which forces evaluation of the stream by writing data out to an external source. +Once an input DStream has been created, you can transform it using _DStream operators_. Most of these operators return new DStreams which you can further transform. Eventually, you'll need to call an _output operator_, which forces evaluation of the DStream by writing data out to an external source. ## Transformations @@ -53,11 +53,11 @@ DStreams support many of the transformations available on normal Spark RDD's:
TransformationMeaning
map(func) Return a new stream formed by passing each element of the source through a function func. Returns a new DStream formed by passing each element of the source through a function func.
filter(func) Return a new stream formed by selecting those elements of the source on which func returns true. Returns a new stream formed by selecting those elements of the source on which func returns true.
flatMap(func)
cogroup(otherStream, [numTasks]) When called on streams of type (K, V) and (K, W), returns a stream of (K, Seq[V], Seq[W]) tuples. This operation is also called groupWith. When called on DStream of type (K, V) and (K, W), returns a DStream of (K, Seq[V], Seq[W]) tuples.
reduce(func) Create a new single-element stream by aggregating the elements of the stream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel. Returns a new DStream of single-element RDDs by aggregating the elements of the stream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel.
transform(func) Returns a new DStream by applying func (a RDD-to-RDD function) to every RDD of the stream. This can be used to do arbitrary RDD operations on the DStream.
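A minimal sketch of the newly documented transform operation, assuming a `wordCounts: DStream[(String, Int)]` as in the examples and `import spark.SparkContext._` in scope for the sort:

{% highlight scala %}
// Hedged sketch: apply an arbitrary RDD-to-RDD function to every batch,
// here sorting each batch's counts by key before printing them.
val sortedCounts = wordCounts.transform(rdd => rdd.sortByKey())
sortedCounts.print()
{% endhighlight %}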
-Spark Streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a windowTime, which represents the width of the window and a slideTime, which represents the frequency during which the window is calculated. +Spark Streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a windowDuration, which represents the width of the window and a slideTime, which represents the frequency during which the window is calculated. - - + - - + - - + -
TransformationMeaning
window(windowTime, slideTime) Return a new stream which is computed based on windowed batches of the source stream. windowTime is the width of the window and slideTime is the frequency during which the window is calculated. Both times must be multiples of the batch interval. + window(windowDuration, slideTime) Return a new stream which is computed based on windowed batches of the source stream. windowDuration is the width of the window and slideTime is the frequency during which the window is calculated. Both times must be multiples of the batch interval.
countByWindow(windowTime, slideTime) Return a sliding count of elements in the stream. windowTime and slideTime are exactly as defined in window(). + countByWindow(windowDuration, slideTime) Return a sliding count of elements in the stream. windowDuration and slideDuration are exactly as defined in window().
reduceByWindow(func, windowTime, slideTime) Return a new single-element stream, created by aggregating elements in the stream over a sliding interval using func. The function should be associative so that it can be computed correctly in parallel. windowTime and slideTime are exactly as defined in window(). + reduceByWindow(func, windowDuration, slideDuration) Return a new single-element stream, created by aggregating elements in the stream over a sliding interval using func. The function should be associative so that it can be computed correctly in parallel. windowDuration and slideDuration are exactly as defined in window().
groupByKeyAndWindow(windowTime, slideTime, [numTasks]) + groupByKeyAndWindow(windowDuration, slideDuration, [numTasks]) When called on a stream of (K, V) pairs, returns a stream of (K, Seq[V]) pairs over a sliding window.
-Note: By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional numTasks argument to set a different number of tasks. windowTime and slideTime are exactly as defined in window(). +Note: By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional numTasks argument to set a different number of tasks. windowDuration and slideDuration are exactly as defined in window().
reduceByKeyAndWindow(func, [numTasks]) When called on a stream of (K, V) pairs, returns a stream of (K, V) pairs where the values for each key are aggregated using the given reduce function over batches within a sliding window. Like in groupByKeyAndWindow, the number of reduce tasks is configurable through an optional second argument. - windowTime and slideTime are exactly as defined in window(). + windowDuration and slideDuration are exactly as defined in window().
countByKeyAndWindow([numTasks]) When called on a stream of (K, V) pairs, returns a stream of (K, Int) pairs where the values for each key are the count within a sliding window. Like in groupByKeyAndWindow, the number of reduce tasks is configurable through an optional second argument. - windowTime and slideTime are exactly as defined in window(). + windowDuration and slideDuration are exactly as defined in window().
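A short usage sketch of these window operations, assuming an existing `ssc`, a `words` DStream, and illustrative durations; the reduceByKeyAndWindow signatures used here match the PairDStreamFunctions changes further down in this patch:

    // Sketch (not from the patch) of the windowed operations above. Assumes
    // `words: DStream[String]` exists on a 1-second batch interval.
    import spark.streaming.Seconds
    import spark.streaming.StreamingContext._   // pair-DStream operations

    val pairs = words.map(word => (word, 1))

    // Plain windowed reduce: recomputes counts over the full 30-second window
    // every 10 seconds.
    val windowedCounts = pairs.reduceByKeyAndWindow(
      (a: Int, b: Int) => a + b,   // associative reduce function
      Seconds(30),                 // windowDuration
      Seconds(10))                 // slideDuration

    // Incremental variant: the second function "subtracts" values that slide out
    // of the window, so only the window's edges are reprocessed.
    val incrementalCounts = pairs.reduceByKeyAndWindow(
      (a: Int, b: Int) => a + b,
      (a: Int, b: Int) => a - b,
      Seconds(30),
      Seconds(10))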
+A complete list of DStream operations is available in the API documentation of [DStream](api/streaming/index.html#spark.streaming.DStream) and [PairDStreamFunctions](api/streaming/index.html#spark.streaming.PairDStreamFunctions). ## Output Operations When an output operator is called, it triggers the computation of a stream. Currently the following output operators are defined: @@ -144,7 +149,7 @@ When an output operator is called, it triggers the computation of a stream. Curr
Operator | Meaning
foreachRDD(func) + foreach(func) The fundamental output operator. Applies a function, func, to each RDD generated from the stream. This function should have side effects, such as printing output, saving the RDD to external files, or writing it over the network to an external system.
saveAsObjectFiles(prefix, [suffix]) Save this DStream's contents as a SequenceFile of serialized objects. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]". + Save this DStream's contents as a SequenceFile of serialized objects. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
saveAsTextFiles(prefix, [suffix]) Save this DStream's contents as text files. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]". + Save this DStream's contents as text files. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
saveAsHadoopFiles(prefix, [suffix]) Save this DStream's contents as a Hadoop file. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]". + Save this DStream's contents as a Hadoop file. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
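A short sketch of how these output operators are driven, assuming an existing `ssc` and a `wordCounts` pair DStream; the output path is a placeholder:

    // Sketch (not from the patch) of the output operators above.
    // `wordCounts: DStream[(String, Long)]` is assumed to exist.
    wordCounts.foreach(rdd => println("Pairs in this batch: " + rdd.count()))

    // Writes one set of files per batch, named "counts-TIME_IN_MS.txt".
    wordCounts.saveAsTextFiles("hdfs:///tmp/streaming/counts", "txt")

    // Nothing is computed until an output operator is registered and the
    // context is started.
    ssc.start()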
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index c89fb7723e..d94548a4f3 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -471,7 +471,7 @@ abstract class DStream[T: ClassManifest] ( * Returns a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): DStream[Int] = this.map(_ => 1).reduce(_ + _) + def count(): DStream[Long] = this.map(_ => 1L).reduce(_ + _) /** * Applies a function to each RDD in this DStream. This is an output operator, so @@ -529,17 +529,16 @@ abstract class DStream[T: ClassManifest] ( * Return a new DStream which is computed based on windowed batches of this DStream. * The new DStream generates RDDs with the same interval as this DStream. * @param windowDuration width of the window; must be a multiple of this DStream's interval. - * @return */ def window(windowDuration: Duration): DStream[T] = window(windowDuration, this.slideDuration) /** * Return a new DStream which is computed based on windowed batches of this DStream. - * @param windowDuration duration (i.e., width) of the window; - * must be a multiple of this DStream's interval + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's interval + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval */ def window(windowDuration: Duration, slideDuration: Duration): DStream[T] = { new WindowedDStream(this, windowDuration, slideDuration) @@ -548,16 +547,22 @@ abstract class DStream[T: ClassManifest] ( /** * Returns a new DStream which computed based on tumbling window on this DStream. * This is equivalent to window(batchTime, batchTime). - * @param batchTime tumbling window duration; must be a multiple of this DStream's interval + * @param batchTime tumbling window duration; must be a multiple of this DStream's + * batching interval */ def tumble(batchTime: Duration): DStream[T] = window(batchTime, batchTime) /** * Returns a new DStream in which each RDD has a single element generated by reducing all - * elements in a window over this DStream. windowDuration and slideDuration are as defined in the - * window() operation. This is equivalent to window(windowDuration, slideDuration).reduce(reduceFunc) + * elements in a window over this DStream. windowDuration and slideDuration are as defined + * in the window() operation. This is equivalent to + * window(windowDuration, slideDuration).reduce(reduceFunc) */ - def reduceByWindow(reduceFunc: (T, T) => T, windowDuration: Duration, slideDuration: Duration): DStream[T] = { + def reduceByWindow( + reduceFunc: (T, T) => T, + windowDuration: Duration, + slideDuration: Duration + ): DStream[T] = { this.window(windowDuration, slideDuration).reduce(reduceFunc) } @@ -577,8 +582,8 @@ abstract class DStream[T: ClassManifest] ( * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the * window() operation. 
This is equivalent to window(windowDuration, slideDuration).count() */ - def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Int] = { - this.map(_ => 1).reduceByWindow(_ + _, _ - _, windowDuration, slideDuration) + def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Long] = { + this.map(_ => 1L).reduceByWindow(_ + _, _ - _, windowDuration, slideDuration) } /** @@ -612,6 +617,8 @@ abstract class DStream[T: ClassManifest] ( /** * Saves each RDD in this DStream as a Sequence file of serialized objects. + * The file name at each batch interval is generated based on `prefix` and + * `suffix`: "prefix-TIME_IN_MS.suffix". */ def saveAsObjectFiles(prefix: String, suffix: String = "") { val saveFunc = (rdd: RDD[T], time: Time) => { @@ -622,7 +629,9 @@ abstract class DStream[T: ClassManifest] ( } /** - * Saves each RDD in this DStream as at text file, using string representation of elements. + * Saves each RDD in this DStream as at text file, using string representation + * of elements. The file name at each batch interval is generated based on + * `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ def saveAsTextFiles(prefix: String, suffix: String = "") { val saveFunc = (rdd: RDD[T], time: Time) => { diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 482d01300d..3952457339 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -25,34 +25,76 @@ extends Serializable { new HashPartitioner(numPartitions) } + /** + * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs are grouped into a + * single sequence to generate the RDDs of the new DStream. Hash partitioning is + * used to generate the RDDs with Spark's default number of partitions. + */ def groupByKey(): DStream[(K, Seq[V])] = { groupByKey(defaultPartitioner()) } + /** + * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs are grouped into a + * single sequence to generate the RDDs of the new DStream. Hash partitioning is + * used to generate the RDDs with `numPartitions` partitions. + */ def groupByKey(numPartitions: Int): DStream[(K, Seq[V])] = { groupByKey(defaultPartitioner(numPartitions)) } + /** + * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs are grouped into a + * single sequence to generate the RDDs of the new DStream. [[spark.Partitioner]] + * is used to control the partitioning of each RDD. + */ def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = { val createCombiner = (v: V) => ArrayBuffer[V](v) val mergeValue = (c: ArrayBuffer[V], v: V) => (c += v) val mergeCombiner = (c1: ArrayBuffer[V], c2: ArrayBuffer[V]) => (c1 ++ c2) - combineByKey(createCombiner, mergeValue, mergeCombiner, partitioner).asInstanceOf[DStream[(K, Seq[V])]] + combineByKey(createCombiner, mergeValue, mergeCombiner, partitioner) + .asInstanceOf[DStream[(K, Seq[V])]] } + /** + * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs is merged using the + * associative reduce function to generate the RDDs of the new DStream. 
+ * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + */ def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = { reduceByKey(reduceFunc, defaultPartitioner()) } + /** + * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs is merged using the + * associative reduce function to generate the RDDs of the new DStream. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + */ def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int): DStream[(K, V)] = { reduceByKey(reduceFunc, defaultPartitioner(numPartitions)) } + /** + * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs is merged using the + * associative reduce function to generate the RDDs of the new DStream. + * [[spark.Partitioner]] is used to control the partitioning of each RDD. + */ def reduceByKey(reduceFunc: (V, V) => V, partitioner: Partitioner): DStream[(K, V)] = { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) combineByKey((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, partitioner) } + /** + * Generic function to combine elements of each key in DStream's RDDs using custom function. + * This is similar to the combineByKey for RDDs. Please refer to combineByKey in + * [[spark.PairRDDFunctions]] for more information. + */ def combineByKey[C: ClassManifest]( createCombiner: V => C, mergeValue: (C, V) => C, @@ -61,14 +103,52 @@ extends Serializable { new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner) } + /** + * Creates a new DStream by counting the number of values of each key in each RDD + * of `this` DStream. Hash partitioning is used to generate the RDDs with Spark's + * `numPartitions` partitions. + */ def countByKey(numPartitions: Int = self.ssc.sc.defaultParallelism): DStream[(K, Long)] = { self.map(x => (x._1, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) } + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * The new DStream generates RDDs with the same interval as this DStream. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + */ + def groupByKeyAndWindow(windowDuration: Duration): DStream[(K, Seq[V])] = { + groupByKeyAndWindow(windowDuration, self.slideDuration, defaultPartitioner()) + } + + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration): DStream[(K, Seq[V])] = { groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner()) } + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. 
+ * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def groupByKeyAndWindow( windowDuration: Duration, slideDuration: Duration, @@ -77,6 +157,16 @@ extends Serializable { groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner(numPartitions)) } + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + */ def groupByKeyAndWindow( windowDuration: Duration, slideDuration: Duration, @@ -85,6 +175,15 @@ extends Serializable { self.window(windowDuration, slideDuration).groupByKey(partitioner) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * The new DStream generates RDDs with the same interval as this DStream. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowDuration: Duration @@ -92,6 +191,17 @@ extends Serializable { reduceByKeyAndWindow(reduceFunc, windowDuration, self.slideDuration, defaultPartitioner()) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowDuration: Duration, @@ -100,6 +210,18 @@ extends Serializable { reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner()) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. 
+ * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowDuration: Duration, @@ -109,6 +231,17 @@ extends Serializable { reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions)) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowDuration: Duration, @@ -121,12 +254,23 @@ extends Serializable { .reduceByKey(cleanedReduceFunc, partitioner) } - // This method is the efficient sliding window reduce operation, - // which requires the specification of an inverse reduce function, - // so that new elements introduced in the window can be "added" using - // reduceFunc to the previous window's result and old elements can be - // "subtracted using invReduceFunc. - + /** + * Creates a new DStream by reducing over a window in a smarter way. + * The reduced value of over a new window is calculated incrementally by using the + * old window's reduce value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, @@ -138,6 +282,24 @@ extends Serializable { reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner()) } + /** + * Creates a new DStream by reducing over a window in a smarter way. + * The reduced value of over a new window is calculated incrementally by using the + * old window's reduce value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". 
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, @@ -150,6 +312,23 @@ extends Serializable { reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions)) } + /** + * Creates a new DStream by reducing over a window in a smarter way. + * The reduced value of over a new window is calculated incrementally by using the + * old window's reduce value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, @@ -164,6 +343,16 @@ extends Serializable { self, cleanedReduceFunc, cleanedInvReduceFunc, windowDuration, slideDuration, partitioner) } + /** + * Creates a new DStream by counting the number of values for each key over a window. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def countByKeyAndWindow( windowDuration: Duration, slideDuration: Duration, @@ -179,17 +368,30 @@ extends Serializable { ) } - // TODO: - // - // - // - // + /** + * Creates a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key from + * `this` DStream. Hash partitioning is used to generate the RDDs with Spark's default + * number of partitions. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @tparam S State type + */ def updateStateByKey[S <: AnyRef : ClassManifest]( updateFunc: (Seq[V], Option[S]) => Option[S] ): DStream[(K, S)] = { updateStateByKey(updateFunc, defaultPartitioner()) } + /** + * Creates a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key from + * `this` DStream. 
Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @param numPartitions Number of partitions of each RDD in the new DStream. + * @tparam S State type + */ def updateStateByKey[S <: AnyRef : ClassManifest]( updateFunc: (Seq[V], Option[S]) => Option[S], numPartitions: Int @@ -197,6 +399,15 @@ extends Serializable { updateStateByKey(updateFunc, defaultPartitioner(numPartitions)) } + /** + * Creates a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key from + * `this` DStream. [[spark.Partitioner]] is used to control the partitioning of each RDD. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + * @tparam S State type + */ def updateStateByKey[S <: AnyRef : ClassManifest]( updateFunc: (Seq[V], Option[S]) => Option[S], partitioner: Partitioner @@ -207,6 +418,19 @@ extends Serializable { updateStateByKey(newUpdateFunc, partitioner, true) } + /** + * Creates a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key from + * `this` DStream. [[spark.Partitioner]] is used to control the partitioning of each RDD. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. Note, that + * this function may generate a different a tuple with a different key + * than the input key. It is up to the developer to decide whether to + * remember the partitioner despite the key being changed. + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + * @param rememberPartitioner Whether to remember the paritioner object in the generated RDDs. + * @tparam S State type + */ def updateStateByKey[S <: AnyRef : ClassManifest]( updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, @@ -226,10 +450,24 @@ extends Serializable { new FlatMapValuedDStream[K, V, U](self, flatMapValuesFunc) } + /** + * Cogroups `this` DStream with `other` DStream. Each RDD of the new DStream will + * be generated by cogrouping RDDs from`this`and `other` DStreams. Therefore, for + * each key k in corresponding RDDs of `this` or `other` DStreams, the generated RDD + * will contains a tuple with the list of values for that key in both RDDs. + * HashPartitioner is used to partition each generated RDD into default number of partitions. + */ def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = { cogroup(other, defaultPartitioner()) } + /** + * Cogroups `this` DStream with `other` DStream. Each RDD of the new DStream will + * be generated by cogrouping RDDs from`this`and `other` DStreams. Therefore, for + * each key k in corresponding RDDs of `this` or `other` DStreams, the generated RDD + * will contains a tuple with the list of values for that key in both RDDs. + * Partitioner is used to partition each generated RDD. 
+ */ def cogroup[W: ClassManifest]( other: DStream[(K, W)], partitioner: Partitioner @@ -249,11 +487,24 @@ extends Serializable { } } + /** + * Joins `this` DStream with `other` DStream. Each RDD of the new DStream will + * be generated by joining RDDs from `this` and `other` DStreams. HashPartitioner is used + * to partition each generated RDD into default number of partitions. + */ def join[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (V, W))] = { join[W](other, defaultPartitioner()) } - def join[W: ClassManifest](other: DStream[(K, W)], partitioner: Partitioner): DStream[(K, (V, W))] = { + /** + * Joins `this` DStream with `other` DStream, that is, each RDD of the new DStream will + * be generated by joining RDDs from `this` and other DStream. Uses the given + * Partitioner to partition each generated RDD. + */ + def join[W: ClassManifest]( + other: DStream[(K, W)], + partitioner: Partitioner + ): DStream[(K, (V, W))] = { this.cogroup(other, partitioner) .flatMapValues{ case (vs, ws) => @@ -261,6 +512,10 @@ extends Serializable { } } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" + */ def saveAsHadoopFiles[F <: OutputFormat[K, V]]( prefix: String, suffix: String @@ -268,6 +523,10 @@ extends Serializable { saveAsHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" + */ def saveAsHadoopFiles( prefix: String, suffix: String, @@ -283,6 +542,10 @@ extends Serializable { self.foreach(saveFunc) } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]]( prefix: String, suffix: String @@ -290,6 +553,10 @@ extends Serializable { saveAsNewAPIHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ def saveAsNewAPIHadoopFiles( prefix: String, suffix: String, diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala index f31ae39a16..03749d4a94 100644 --- a/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala @@ -81,7 +81,7 @@ object RawTextHelper { * before real workload starts. */ def warmUp(sc: SparkContext) { - for(i <- 0 to 4) { + for(i <- 0 to 1) { sc.parallelize(1 to 200000, 1000) .map(_ % 1331).map(_.toString) .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) -- cgit v1.2.3 From 131be5d62ef6b770de5106eb268a45bca385b599 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 14 Jan 2013 03:28:25 -0800 Subject: Fixed bug in RDD checkpointing. 
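The one-line fix in this commit closes the serialization stream instead of the underlying file stream. As a rough illustration of why that ordering matters, here is a sketch using plain java.io classes as a stand-in for Spark's serializer (assumed, not the actual checkpoint code):

    // Illustrative only: closing the outermost stream flushes its buffer before
    // closing the file; closing only the file stream can drop records still
    // buffered in the serialization layer.
    import java.io.{BufferedOutputStream, FileOutputStream, ObjectOutputStream}

    val fileOut = new BufferedOutputStream(new FileOutputStream("/tmp/checkpoint-part"))
    val serOut = new ObjectOutputStream(fileOut)
    (1 to 1000).foreach(i => serOut.writeObject("record-" + i))
    serOut.close()   // flush the serializer's buffer, then close fileOut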
--- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 86c63ca2f4..6f00f6ac73 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -80,12 +80,12 @@ private[spark] object CheckpointRDD extends Logging { val serializer = SparkEnv.get.serializer.newInstance() val serializeStream = serializer.serializeStream(fileOutputStream) serializeStream.writeAll(iterator) - fileOutputStream.close() + serializeStream.close() if (!fs.rename(tempOutputPath, finalOutputPath)) { if (!fs.delete(finalOutputPath, true)) { throw new IOException("Checkpoint failed: failed to delete earlier output of task " - + context.attemptId); + + context.attemptId) } if (!fs.rename(tempOutputPath, finalOutputPath)) { throw new IOException("Checkpoint failed: failed to save output of task: " @@ -119,7 +119,7 @@ private[spark] object CheckpointRDD extends Logging { val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) val path = new Path(hdfsPath, "temp") val fs = path.getFileSystem(new Configuration()) - sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 10) _) + sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same") assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same") -- cgit v1.2.3 From 273fb5cc109ac0a032f84c1566ae908cd0eb27b6 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Thu, 3 Jan 2013 14:09:56 -0800 Subject: Throw FetchFailedException for cached missing locs --- core/src/main/scala/spark/MapOutputTracker.scala | 36 +++++++++++++++++------- 1 file changed, 26 insertions(+), 10 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 70eb9f702e..9f2aa76830 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -139,8 +139,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea case e: InterruptedException => } } - return mapStatuses.get(shuffleId).map(status => - (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, + mapStatuses.get(shuffleId)) } else { fetching += shuffleId } @@ -156,21 +156,15 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea fetchedStatuses = deserializeStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) - if (fetchedStatuses.contains(null)) { - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing an output location for shuffle " + shuffleId)) - } } finally { fetching.synchronized { fetching -= shuffleId fetching.notifyAll() } } - return fetchedStatuses.map(s => - (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } else { - return statuses.map(s => - (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) } } @@ -258,6 +252,28 @@ private[spark] class 
MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea private[spark] object MapOutputTracker { private val LOG_BASE = 1.1 + // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If + // any of the statuses is null (indicating a missing location due to a failed mapper), + // throw a FetchFailedException. + def convertMapStatuses( + shuffleId: Int, + reduceId: Int, + statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { + if (statuses == null) { + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing all output locations for shuffle " + shuffleId)) + } + statuses.map { + status => + if (status == null) { + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing an output location for shuffle " + shuffleId)) + } else { + (status.address, decompressSize(status.compressedSizes(reduceId))) + } + } + } + /** * Compress a size in bytes to 8 bits for efficient reporting of map output sizes. * We do this by encoding the log base 1.1 of the size as an integer, which can support -- cgit v1.2.3 From 7ba34bc007ec10d12b2a871749f32232cdbc0d9c Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 14 Jan 2013 15:24:08 -0800 Subject: Additional tests for MapOutputTracker. --- .../test/scala/spark/MapOutputTrackerSuite.scala | 82 +++++++++++++++++++++- 1 file changed, 80 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 5b4b198960..6c6f82e274 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -1,12 +1,18 @@ package spark import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter import akka.actor._ import spark.scheduler.MapStatus import spark.storage.BlockManagerId +import spark.util.AkkaUtils -class MapOutputTrackerSuite extends FunSuite { +class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { + after { + System.clearProperty("spark.master.port") + } + test("compressSize") { assert(MapOutputTracker.compressSize(0L) === 0) assert(MapOutputTracker.compressSize(1L) === 1) @@ -71,6 +77,78 @@ class MapOutputTrackerSuite extends FunSuite { // The remaining reduce task might try to grab the output dispite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. 
- intercept[Exception] { tracker.getServerStatuses(10, 1) } + intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } + } + + test("remote fetch") { + val (actorSystem, boundPort) = + AkkaUtils.createActorSystem("test", "localhost", 0) + System.setProperty("spark.master.port", boundPort.toString) + val masterTracker = new MapOutputTracker(actorSystem, true) + val slaveTracker = new MapOutputTracker(actorSystem, false) + masterTracker.registerShuffle(10, 1) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + new BlockManagerId("hostA", 1000), Array(compressedSize1000))) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((new BlockManagerId("hostA", 1000), size1000))) + + masterTracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + } + + test("simulatenous fetch fails") { + val dummyActorSystem = ActorSystem("testDummy") + val dummyTracker = new MapOutputTracker(dummyActorSystem, true) + dummyTracker.registerShuffle(10, 1) + // val compressedSize1000 = MapOutputTracker.compressSize(1000L) + // val size100 = MapOutputTracker.decompressSize(compressedSize1000) + // dummyTracker.registerMapOutput(10, 0, new MapStatus( + // new BlockManagerId("hostA", 1000), Array(compressedSize1000))) + val serializedMessage = dummyTracker.getSerializedLocations(10) + + val (actorSystem, boundPort) = + AkkaUtils.createActorSystem("test", "localhost", 0) + System.setProperty("spark.master.port", boundPort.toString) + val delayResponseLock = new java.lang.Object + val delayResponseActor = actorSystem.actorOf(Props(new Actor { + override def receive = { + case GetMapOutputStatuses(shuffleId: Int, requester: String) => + delayResponseLock.synchronized { + sender ! 
serializedMessage + } + } + }), name = "MapOutputTracker") + val slaveTracker = new MapOutputTracker(actorSystem, false) + var firstFailed = false + var secondFailed = false + val firstFetch = new Thread { + override def run() { + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + firstFailed = true + } + } + val secondFetch = new Thread { + override def run() { + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + secondFailed = true + } + } + delayResponseLock.synchronized { + firstFetch.start + secondFetch.start + } + firstFetch.join + secondFetch.join + assert(firstFailed && secondFailed) } } -- cgit v1.2.3 From b0389997972d383c3aaa87924b725dee70b18d8e Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 14 Jan 2013 17:04:44 -0800 Subject: Fix accidental spark.master.host reuse --- core/src/test/scala/spark/MapOutputTrackerSuite.scala | 2 ++ 1 file changed, 2 insertions(+) (limited to 'core') diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 6c6f82e274..aa1d8ac7e6 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -81,6 +81,7 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { } test("remote fetch") { + System.clearProperty("spark.master.host") val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) System.setProperty("spark.master.port", boundPort.toString) @@ -107,6 +108,7 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { } test("simulatenous fetch fails") { + System.clearProperty("spark.master.host") val dummyActorSystem = ActorSystem("testDummy") val dummyTracker = new MapOutputTracker(dummyActorSystem, true) dummyTracker.registerShuffle(10, 1) -- cgit v1.2.3 From dd583b7ebf0e6620ec8e35424b59db451febe3e8 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 15 Jan 2013 10:52:06 -0600 Subject: Call executeOnCompleteCallbacks in a finally block. --- core/src/main/scala/spark/scheduler/ResultTask.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index e492279b4e..2aad7956b4 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -15,9 +15,11 @@ private[spark] class ResultTask[T, U]( override def run(attemptId: Long): U = { val context = new TaskContext(stageId, partition, attemptId) - val result = func(context, rdd.iterator(split, context)) - context.executeOnCompleteCallbacks() - result + try { + func(context, rdd.iterator(split, context)) + } finally { + context.executeOnCompleteCallbacks() + } } override def preferredLocations: Seq[String] = locs -- cgit v1.2.3 From d228bff440395e8e6b8d67483467dde65b08ab40 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 15 Jan 2013 11:48:50 -0600 Subject: Add a test. 
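The new test below checks that completion callbacks fire even when the task body fails, which is what the try/finally change in the previous commit guarantees. A rough sketch of the pattern, with assumed stand-in classes rather than the real TaskContext or ResultTask:

    // Illustrative stand-ins for the pattern under test: callbacks registered on
    // a task context run even if the task body throws.
    import scala.collection.mutable.ArrayBuffer

    class SimpleTaskContext {
      private val callbacks = new ArrayBuffer[() => Unit]
      def addOnCompleteCallback(f: () => Unit) { callbacks += f }
      def executeOnCompleteCallbacks() { callbacks.foreach(_()) }
    }

    def runTask[U](body: SimpleTaskContext => U): U = {
      val context = new SimpleTaskContext
      try {
        body(context)                           // may throw
      } finally {
        context.executeOnCompleteCallbacks()    // still runs on failure
      }
    }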
--- .../scala/spark/scheduler/TaskContextSuite.scala | 43 ++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 core/src/test/scala/spark/scheduler/TaskContextSuite.scala (limited to 'core') diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala new file mode 100644 index 0000000000..f937877340 --- /dev/null +++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala @@ -0,0 +1,43 @@ +package spark.scheduler + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import spark.TaskContext +import spark.RDD +import spark.SparkContext +import spark.Split + +class TaskContextSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + + after { + if (sc != null) { + sc.stop() + sc = null + } + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") + } + + test("Calls executeOnCompleteCallbacks after failure") { + var completed = false + sc = new SparkContext("local", "test") + val rdd = new RDD[String](sc) { + override val splits = Array[Split](StubSplit(0)) + override val dependencies = List() + override def compute(split: Split, context: TaskContext) = { + context.addOnCompleteCallback(() => completed = true) + sys.error("failed") + } + } + val func = (c: TaskContext, i: Iterator[String]) => i.next + val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0) + intercept[RuntimeException] { + task.run(0) + } + assert(completed === true) + } + + case class StubSplit(val index: Int) extends Split +} \ No newline at end of file -- cgit v1.2.3 From 4078623b9f2a338d4992c3dfd3af3a5550615180 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 15 Jan 2013 12:05:54 -0800 Subject: Remove broken attempt to test fetching case. --- .../test/scala/spark/MapOutputTrackerSuite.scala | 48 +--------------------- 1 file changed, 2 insertions(+), 46 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index aa1d8ac7e6..d3dd3a8fa4 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -105,52 +105,8 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - } - - test("simulatenous fetch fails") { - System.clearProperty("spark.master.host") - val dummyActorSystem = ActorSystem("testDummy") - val dummyTracker = new MapOutputTracker(dummyActorSystem, true) - dummyTracker.registerShuffle(10, 1) - // val compressedSize1000 = MapOutputTracker.compressSize(1000L) - // val size100 = MapOutputTracker.decompressSize(compressedSize1000) - // dummyTracker.registerMapOutput(10, 0, new MapStatus( - // new BlockManagerId("hostA", 1000), Array(compressedSize1000))) - val serializedMessage = dummyTracker.getSerializedLocations(10) - val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("test", "localhost", 0) - System.setProperty("spark.master.port", boundPort.toString) - val delayResponseLock = new java.lang.Object - val delayResponseActor = actorSystem.actorOf(Props(new Actor { - override def receive = { - case GetMapOutputStatuses(shuffleId: Int, requester: String) => - delayResponseLock.synchronized { - sender ! 
serializedMessage - } - } - }), name = "MapOutputTracker") - val slaveTracker = new MapOutputTracker(actorSystem, false) - var firstFailed = false - var secondFailed = false - val firstFetch = new Thread { - override def run() { - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - firstFailed = true - } - } - val secondFetch = new Thread { - override def run() { - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - secondFailed = true - } - } - delayResponseLock.synchronized { - firstFetch.start - secondFetch.start - } - firstFetch.join - secondFetch.join - assert(firstFailed && secondFailed) + // failure should be cached + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } } } -- cgit v1.2.3 From a805ac4a7cdd520b6141dd885c780c526bb54ba6 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 Jan 2013 10:55:26 -0800 Subject: Disabled checkpoint for PairwiseRDD (pySpark). --- core/src/main/scala/spark/api/python/PythonRDD.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 276035a9ad..0138b22d38 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -138,6 +138,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends case Seq(a, b) => (a, b) case x => throw new Exception("PairwiseRDD: unexpected value: " + x) } + override def checkpoint() { } val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } -- cgit v1.2.3 From eae698f755f41fd8bdff94c498df314ed74aa3c1 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 16 Jan 2013 12:21:37 -0800 Subject: remove unused thread pool --- core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala | 3 --- 1 file changed, 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 915f71ba9f..a29bf974d2 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -24,9 +24,6 @@ private[spark] class StandaloneExecutorBackend( with ExecutorBackend with Logging { - val threadPool = new ThreadPoolExecutor( - 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) - var master: ActorRef = null override def preStart() { -- cgit v1.2.3 From 742bc841adb2a57b05e7a155681a162ab9dfa2c1 Mon Sep 17 00:00:00 2001 From: Fernand Pajot Date: Thu, 17 Jan 2013 16:56:11 -0800 Subject: changed HttpBroadcast server cache to be in spark.local.dir instead of java.io.tmpdir --- core/src/main/scala/spark/broadcast/HttpBroadcast.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 7eb4ddb74f..96dc28f12a 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -89,7 +89,7 @@ private object HttpBroadcast extends Logging { } private def createServer() { - broadcastDir = Utils.createTempDir() + broadcastDir = Utils.createTempDir(System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) server = new HttpServer(broadcastDir) server.start() serverUri = server.uri -- cgit v1.2.3 From 
54c0f9f185576e9b844fa8f81ca410f188daa51c Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 17 Jan 2013 17:40:55 -0800 Subject: Fix code that assumed spark.local.dir is only a single directory --- core/src/main/scala/spark/Utils.scala | 11 ++++++++++- core/src/main/scala/spark/broadcast/HttpBroadcast.scala | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 0e7007459d..aeed5d2f32 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -134,7 +134,7 @@ private object Utils extends Logging { */ def fetchFile(url: String, targetDir: File) { val filename = url.split("/").last - val tempDir = System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")) + val tempDir = getLocalDir val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) val targetFile = new File(targetDir, filename) val uri = new URI(url) @@ -204,6 +204,15 @@ private object Utils extends Logging { FileUtil.chmod(filename, "a+x") } + /** + * Get a temporary directory using Spark's spark.local.dir property, if set. This will always + * return a single directory, even though the spark.local.dir property might be a list of + * multiple paths. + */ + def getLocalDir: String = { + System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0) + } + /** * Shuffle the elements of a collection into a random order, returning the * result in a new collection. Unlike scala.util.Random.shuffle, this method diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 96dc28f12a..856a4683a9 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -89,7 +89,7 @@ private object HttpBroadcast extends Logging { } private def createServer() { - broadcastDir = Utils.createTempDir(System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) + broadcastDir = Utils.createTempDir(Utils.getLocalDir) server = new HttpServer(broadcastDir) server.start() serverUri = server.uri -- cgit v1.2.3 From d5570c7968baba1c1fe86c68dc1c388fae23907b Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 16 Jan 2013 21:07:49 -0800 Subject: Adding checkpointing to Java API --- .../main/scala/spark/api/java/JavaRDDLike.scala | 28 ++++++++++++++++++++++ .../scala/spark/api/java/JavaSparkContext.scala | 26 ++++++++++++++++++++ core/src/test/scala/spark/JavaAPISuite.java | 27 +++++++++++++++++++++ 3 files changed, 81 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 81d3a94466..958f5c26a1 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -9,6 +9,7 @@ import spark.api.java.JavaPairRDD._ import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} import spark.partial.{PartialResult, BoundedDouble} import spark.storage.StorageLevel +import com.google.common.base.Optional trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { @@ -298,4 +299,31 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Save this RDD as a SequenceFile of serialized objects. 
*/ def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path) + + /** + * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` + * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. + * This is used to truncate very long lineages. In the current implementation, Spark will save + * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. + * Hence, it is strongly recommended to use checkpoint() on RDDs when + * (i) checkpoint() is called before the any job has been executed on this RDD. + * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will + * require recomputation. + */ + def checkpoint() = rdd.checkpoint() + + /** + * Return whether this RDD has been checkpointed or not + */ + def isCheckpointed(): Boolean = rdd.isCheckpointed() + + /** + * Gets the name of the file to which this RDD was checkpointed + */ + def getCheckpointFile(): Optional[String] = { + rdd.getCheckpointFile match { + case Some(file) => Optional.of(file) + case _ => Optional.absent() + } + } } diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index bf9ad7a200..22bfa2280d 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -342,6 +342,32 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def clearFiles() { sc.clearFiles() } + + /** + * Set the directory under which RDDs are going to be checkpointed. This method will + * create this directory and will throw an exception of the path already exists (to avoid + * overwriting existing files may be overwritten). The directory will be deleted on exit + * if indicated. + */ + def setCheckpointDir(dir: String, useExisting: Boolean) { + sc.setCheckpointDir(dir, useExisting) + } + + /** + * Set the directory under which RDDs are going to be checkpointed. This method will + * create this directory and will throw an exception of the path already exists (to avoid + * overwriting existing files may be overwritten). The directory will be deleted on exit + * if indicated. 
+ */ + def setCheckpointDir(dir: String) { + sc.setCheckpointDir(dir) + } + + protected def checkpointFile[T](path: String): JavaRDD[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + new JavaRDD(sc.checkpointFile(path)) + } } object JavaSparkContext { diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index b99e790093..0b5354774b 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -625,4 +625,31 @@ public class JavaAPISuite implements Serializable { }); Assert.assertEquals((Float) 25.0f, floatAccum.value()); } + + @Test + public void checkpointAndComputation() { + File tempDir = Files.createTempDir(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + sc.setCheckpointDir(tempDir.getAbsolutePath(), true); + Assert.assertEquals(false, rdd.isCheckpointed()); + rdd.checkpoint(); + rdd.count(); // Forces the DAG to cause a checkpoint + Assert.assertEquals(true, rdd.isCheckpointed()); + Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect()); + } + + @Test + public void checkpointAndRestore() { + File tempDir = Files.createTempDir(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + sc.setCheckpointDir(tempDir.getAbsolutePath(), true); + Assert.assertEquals(false, rdd.isCheckpointed()); + rdd.checkpoint(); + rdd.count(); // Forces the DAG to cause a checkpoint + Assert.assertEquals(true, rdd.isCheckpointed()); + + Assert.assertTrue(rdd.getCheckpointFile().isPresent()); + JavaRDD recovered = sc.checkpointFile(rdd.getCheckpointFile().get()); + Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect()); + } } -- cgit v1.2.3 From 214345ceace634ec9cc83c4c85b233b699e0d219 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 19 Jan 2013 23:50:17 -0800 Subject: Fixed issue https://spark-project.atlassian.net/browse/STREAMING-29, along with updates to doc comments in SparkContext.checkpoint(). --- core/src/main/scala/spark/RDD.scala | 17 ++++++++--------- core/src/main/scala/spark/RDDCheckpointData.scala | 2 +- core/src/main/scala/spark/SparkContext.scala | 13 +++++++------ streaming/src/main/scala/spark/streaming/DStream.scala | 8 +++++++- 4 files changed, 23 insertions(+), 17 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index a9f2e86455..e0d2eabb1d 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -549,17 +549,16 @@ abstract class RDD[T: ClassManifest]( } /** - * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` - * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. - * This is used to truncate very long lineages. In the current implementation, Spark will save - * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. - * Hence, it is strongly recommended to use checkpoint() on RDDs when - * (i) checkpoint() is called before the any job has been executed on this RDD. - * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will - * require recomputation. + * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint + * directory set with SparkContext.setCheckpointDir() and all references to its parent + * RDDs will be removed. This function must be called before any job has been + * executed on this RDD. 
It is strongly recommended that this RDD is persisted in + * memory, otherwise saving it on a file will require recomputation. */ def checkpoint() { - if (checkpointData.isEmpty) { + if (context.checkpointDir.isEmpty) { + throw new Exception("Checkpoint directory has not been set in the SparkContext") + } else if (checkpointData.isEmpty) { checkpointData = Some(new RDDCheckpointData(this)) checkpointData.get.markForCheckpoint() } diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index d845a522e4..18df530b7d 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -63,7 +63,7 @@ extends Logging with Serializable { } // Save to file, and reload it as an RDD - val path = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString + val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _) val newRDD = new CheckpointRDD[T](rdd.context, path) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 88cf357ebf..7f3259d982 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -184,7 +184,7 @@ class SparkContext( private var dagScheduler = new DAGScheduler(taskScheduler) - private[spark] var checkpointDir: String = null + private[spark] var checkpointDir: Option[String] = None // Methods for creating RDDs @@ -595,10 +595,11 @@ class SparkContext( } /** - * Set the directory under which RDDs are going to be checkpointed. This method will - * create this directory and will throw an exception of the path already exists (to avoid - * overwriting existing files may be overwritten). The directory will be deleted on exit - * if indicated. + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists and useExisting is set to true, then the + * exisiting directory will be used. Otherwise an exception will be thrown to + * prevent accidental overriding of checkpoint files in the existing directory. */ def setCheckpointDir(dir: String, useExisting: Boolean = false) { val path = new Path(dir) @@ -610,7 +611,7 @@ class SparkContext( fs.mkdirs(path) } } - checkpointDir = dir + checkpointDir = Some(dir) } /** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */ diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index fbe3cebd6d..c4442b6a0c 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -154,10 +154,16 @@ abstract class DStream[T: ClassManifest] ( assert( !mustCheckpoint || checkpointDuration != null, - "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " + + "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set." + " Please use DStream.checkpoint() to set the interval." ) + assert( + checkpointDuration == null || ssc.sc.checkpointDir.isDefined, + "The checkpoint directory has not been set. Please use StreamingContext.checkpoint()" + + " or SparkContext.checkpoint() to set the checkpoint directory." 
+ ) + assert( checkpointDuration == null || checkpointDuration >= slideDuration, "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + -- cgit v1.2.3 From 8e7f098a2c9e5e85cb9435f28d53a3a5847c14aa Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 01:57:44 -0800 Subject: Added accumulators to PySpark --- .../main/scala/spark/api/python/PythonRDD.scala | 83 +++++++++-- python/pyspark/__init__.py | 4 + python/pyspark/accumulators.py | 166 +++++++++++++++++++++ python/pyspark/context.py | 38 +++++ python/pyspark/rdd.py | 2 +- python/pyspark/serializers.py | 7 +- python/pyspark/shell.py | 4 +- python/pyspark/worker.py | 7 +- 8 files changed, 290 insertions(+), 21 deletions(-) create mode 100644 python/pyspark/accumulators.py (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index f431ef28d3..fb13e84658 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -1,7 +1,8 @@ package spark.api.python import java.io._ -import java.util.{List => JList} +import java.net._ +import java.util.{List => JList, ArrayList => JArrayList, Collections} import scala.collection.JavaConversions._ import scala.io.Source @@ -10,25 +11,26 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import spark.broadcast.Broadcast import spark._ import spark.rdd.PipedRDD -import java.util private[spark] class PythonRDD[T: ClassManifest]( - parent: RDD[T], - command: Seq[String], - envVars: java.util.Map[String, String], - preservePartitoning: Boolean, - pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) + parent: RDD[T], + command: Seq[String], + envVars: java.util.Map[String, String], + preservePartitoning: Boolean, + pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent.context) { // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. 
by spaces) def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + preservePartitoning: Boolean, pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) = this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec, - broadcastVars) + broadcastVars, accumulator) override def splits = parent.splits @@ -93,18 +95,30 @@ private[spark] class PythonRDD[T: ClassManifest]( // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(proc.getInputStream) return new Iterator[Array[Byte]] { - def next() = { + def next(): Array[Byte] = { val obj = _nextObj _nextObj = read() obj } - private def read() = { + private def read(): Array[Byte] = { try { val length = stream.readInt() - val obj = new Array[Byte](length) - stream.readFully(obj) - obj + if (length != -1) { + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + } else { + // We've finished the data section of the output, but we can still read some + // accumulator updates; let's do that, breaking when we get EOFException + while (true) { + val len2 = stream.readInt() + val update = new Array[Byte](len2) + stream.readFully(update) + accumulator += Collections.singletonList(update) + } + new Array[Byte](0) + } } catch { case eof: EOFException => { val exitStatus = proc.waitFor() @@ -246,3 +260,40 @@ private class ExtractValue extends spark.api.java.function.Function[(Array[Byte] private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } + +/** + * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it + * collects a list of pickled strings that we pass to Python through a socket. + */ +class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) + extends AccumulatorParam[JList[Array[Byte]]] { + + override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList + + override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) + : JList[Array[Byte]] = { + if (serverHost == null) { + // This happens on the worker node, where we just want to remember all the updates + val1.addAll(val2) + val1 + } else { + // This happens on the master, where we pass the updates to Python through a socket + val socket = new Socket(serverHost, serverPort) + val in = socket.getInputStream + val out = new DataOutputStream(socket.getOutputStream) + out.writeInt(val2.size) + for (array <- val2) { + out.writeInt(array.length) + out.write(array) + } + out.flush() + // Wait for a byte from the Python side as an acknowledgement + val byteRead = in.read() + if (byteRead == -1) { + throw new SparkException("EOF reached before Python server acknowledged") + } + socket.close() + null + } + } +} diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index c595ae0842..00666bc0a3 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -7,6 +7,10 @@ Public classes: Main entry point for Spark functionality. - L{RDD} A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + - L{Broadcast} + A broadcast variable that gets reused across tasks. + - L{Accumulator} + An "add-only" shared variable that tasks can only add values to. 
""" import sys import os diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py new file mode 100644 index 0000000000..438af4cfc0 --- /dev/null +++ b/python/pyspark/accumulators.py @@ -0,0 +1,166 @@ +""" +>>> from pyspark.context import SparkContext +>>> sc = SparkContext('local', 'test') +>>> a = sc.accumulator(1) +>>> a.value +1 +>>> a.value = 2 +>>> a.value +2 +>>> a += 5 +>>> a.value +7 + +>>> rdd = sc.parallelize([1,2,3]) +>>> def f(x): +... global a +... a += x +>>> rdd.foreach(f) +>>> a.value +13 + +>>> class VectorAccumulatorParam(object): +... def zero(self, value): +... return [0.0] * len(value) +... def addInPlace(self, val1, val2): +... for i in xrange(len(val1)): +... val1[i] += val2[i] +... return val1 +>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) +>>> va.value +[1.0, 2.0, 3.0] +>>> def g(x): +... global va +... va += [x] * 3 +>>> rdd.foreach(g) +>>> va.value +[7.0, 8.0, 9.0] + +>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Py4JJavaError:... + +>>> def h(x): +... global a +... a.value = 7 +>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Py4JJavaError:... + +>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Exception:... +""" + +import struct +import SocketServer +import threading +from pyspark.cloudpickle import CloudPickler +from pyspark.serializers import read_int, read_with_length, load_pickle + + +# Holds accumulators registered on the current machine, keyed by ID. This is then used to send +# the local accumulator updates back to the driver program at the end of a task. +_accumulatorRegistry = {} + + +def _deserialize_accumulator(aid, zero_value, accum_param): + from pyspark.accumulators import _accumulatorRegistry + accum = Accumulator(aid, zero_value, accum_param) + accum._deserialized = True + _accumulatorRegistry[aid] = accum + return accum + + +class Accumulator(object): + def __init__(self, aid, value, accum_param): + """Create a new Accumulator with a given initial value and AccumulatorParam object""" + from pyspark.accumulators import _accumulatorRegistry + self.aid = aid + self.accum_param = accum_param + self._value = value + self._deserialized = False + _accumulatorRegistry[aid] = self + + def __reduce__(self): + """Custom serialization; saves the zero value from our AccumulatorParam""" + param = self.accum_param + return (_deserialize_accumulator, (self.aid, param.zero(self._value), param)) + + @property + def value(self): + """Get the accumulator's value; only usable in driver program""" + if self._deserialized: + raise Exception("Accumulator.value cannot be accessed inside tasks") + return self._value + + @value.setter + def value(self, value): + """Sets the accumulator's value; only usable in driver program""" + if self._deserialized: + raise Exception("Accumulator.value cannot be accessed inside tasks") + self._value = value + + def __iadd__(self, term): + """The += operator; adds a term to this accumulator's value""" + self._value = self.accum_param.addInPlace(self._value, term) + return self + + def __str__(self): + return str(self._value) + + +class AddingAccumulatorParam(object): + """ + An AccumulatorParam that uses the + operators to add values. Designed for simple types + such as integers, floats, and lists. Requires the zero value for the underlying type + as a parameter. 
+ """ + + def __init__(self, zero_value): + self.zero_value = zero_value + + def zero(self, value): + return self.zero_value + + def addInPlace(self, value1, value2): + value1 += value2 + return value1 + + +# Singleton accumulator params for some standard types +INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0) +DOUBLE_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) +COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) + + +class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + def handle(self): + from pyspark.accumulators import _accumulatorRegistry + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = load_pickle(read_with_length(self.rfile)) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + + +def _start_update_server(): + """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" + server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler) + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + return server + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index e486f206b0..1e2f845f9c 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -2,6 +2,8 @@ import os import atexit from tempfile import NamedTemporaryFile +from pyspark import accumulators +from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast from pyspark.java_gateway import launch_gateway from pyspark.serializers import dump_pickle, write_with_length, batched @@ -22,6 +24,7 @@ class SparkContext(object): _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile _takePartition = jvm.PythonRDD.takePartition + _next_accum_id = 0 def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): @@ -52,6 +55,14 @@ class SparkContext(object): self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, empty_string_array) + # Create a single Accumulator in Java that we'll send all our updates through; + # they will be passed back to us through a TCP server + self._accumulatorServer = accumulators._start_update_server() + (host, port) = self._accumulatorServer.server_address + self._javaAccumulator = self._jsc.accumulator( + self.jvm.java.util.ArrayList(), + self.jvm.PythonAccumulatorParam(host, port)) + self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have @@ -74,6 +85,8 @@ class SparkContext(object): def __del__(self): if self._jsc: self._jsc.stop() + if self._accumulatorServer: + self._accumulatorServer.shutdown() def stop(self): """ @@ -129,6 +142,31 @@ class SparkContext(object): return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) + def accumulator(self, value, accum_param=None): + """ + Create an C{Accumulator} with the given initial value, using a given + AccumulatorParam helper object to define how to add values of the data + type if provided. Default AccumulatorParams are used for integers and + floating-point numbers if you do not provide one. 
For other types, the + AccumulatorParam must implement two methods: + - C{zero(value)}: provide a "zero value" for the type, compatible in + dimensions with the provided C{value} (e.g., a zero vector). + - C{addInPlace(val1, val2)}: add two values of the accumulator's data + type, returning a new value; for efficiency, can also update C{val1} + in place and return it. + """ + if accum_param == None: + if isinstance(value, int): + accum_param = accumulators.INT_ACCUMULATOR_PARAM + elif isinstance(value, float): + accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM + elif isinstance(value, complex): + accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM + else: + raise Exception("No default accumulator param for type %s" % type(value)) + SparkContext._next_accum_id += 1 + return Accumulator(SparkContext._next_accum_id - 1, value, accum_param) + def addFile(self, path): """ Add a file to be downloaded into the working directory of this Spark diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1d36da42b0..d705f0f9e1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -703,7 +703,7 @@ class PipelinedRDD(RDD): env = MapConverter().convert(env, self.ctx.gateway._gateway_client) python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, class_manifest) + broadcast_vars, self.ctx._javaAccumulator, class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 9a5151ea00..115cf28cc2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -52,8 +52,13 @@ def read_int(stream): raise EOFError return struct.unpack("!i", length)[0] + +def write_int(value, stream): + stream.write(struct.pack("!i", value)) + + def write_with_length(obj, stream): - stream.write(struct.pack("!i", len(obj))) + write_int(len(obj), stream) stream.write(obj) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 7e6ad3aa76..f6328c561f 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -1,7 +1,7 @@ """ An interactive shell. -This fle is designed to be launched as a PYTHONSTARTUP script. +This file is designed to be launched as a PYTHONSTARTUP script. """ import os from pyspark.context import SparkContext @@ -14,4 +14,4 @@ print "Spark context avaiable as sc." # which allows us to execute the user's PYTHONSTARTUP file: _pythonstartup = os.environ.get('OLD_PYTHONSTARTUP') if _pythonstartup and os.path.isfile(_pythonstartup): - execfile(_pythonstartup) + execfile(_pythonstartup) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 3d792bbaa2..b2b9288089 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -5,9 +5,10 @@ import sys from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. 
+from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import write_with_length, read_with_length, \ +from pyspark.serializers import write_with_length, read_with_length, write_int, \ read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file @@ -36,6 +37,10 @@ def main(): iterator = read_from_pickle_file(sys.stdin) for obj in func(split_index, iterator): write_with_length(dumps(obj), old_stdout) + # Mark the beginning of the accumulators section of the output + write_int(-1, old_stdout) + for aid, accum in _accumulatorRegistry.items(): + write_with_length(dump_pickle((aid, accum._value)), old_stdout) if __name__ == '__main__': -- cgit v1.2.3 From 7ed1bf4b485131d58ea6728e7247b79320aca9e6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 16 Jan 2013 19:15:14 -0800 Subject: Add RDD checkpointing to Python API. --- .../main/scala/spark/api/python/PythonRDD.scala | 3 -- python/epydoc.conf | 2 +- python/pyspark/context.py | 9 +++++ python/pyspark/rdd.py | 34 ++++++++++++++++ python/pyspark/tests.py | 46 ++++++++++++++++++++++ python/run-tests | 3 ++ 6 files changed, 93 insertions(+), 4 deletions(-) create mode 100644 python/pyspark/tests.py (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 89f7c316dc..8c38262dd8 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -135,8 +135,6 @@ private[spark] class PythonRDD[T: ClassManifest]( } } - override def checkpoint() { } - val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } @@ -152,7 +150,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends case Seq(a, b) => (a, b) case x => throw new Exception("PairwiseRDD: unexpected value: " + x) } - override def checkpoint() { } val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } diff --git a/python/epydoc.conf b/python/epydoc.conf index 91ac984ba2..45102cd9fe 100644 --- a/python/epydoc.conf +++ b/python/epydoc.conf @@ -16,4 +16,4 @@ target: docs/ private: no exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers - pyspark.java_gateway pyspark.examples pyspark.shell + pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 1e2f845f9c..a438b43fdc 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -195,3 +195,12 @@ class SparkContext(object): filename = path.split("/")[-1] os.environ["PYTHONPATH"] = \ "%s:%s" % (filename, os.environ["PYTHONPATH"]) + + def setCheckpointDir(self, dirName, useExisting=False): + """ + Set the directory under which RDDs are going to be checkpointed. This + method will create this directory and will throw an exception of the + path already exists (to avoid overwriting existing files may be + overwritten). The directory will be deleted on exit if indicated. + """ + self._jsc.sc().setCheckpointDir(dirName, useExisting) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d705f0f9e1..9b676cae4a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -49,6 +49,40 @@ class RDD(object): self._jrdd.cache() return self + def checkpoint(self): + """ + Mark this RDD for checkpointing. 
The RDD will be saved to a file inside + `checkpointDir` (set using setCheckpointDir()) and all references to + its parent RDDs will be removed. This is used to truncate very long + lineages. In the current implementation, Spark will save this RDD to + a file (using saveAsObjectFile()) after the first job using this RDD is + done. Hence, it is strongly recommended to use checkpoint() on RDDs + when + + (i) checkpoint() is called before the any job has been executed on this + RDD. + + (ii) This RDD has been made to persist in memory. Otherwise saving it + on a file will require recomputation. + """ + self._jrdd.rdd().checkpoint() + + def isCheckpointed(self): + """ + Return whether this RDD has been checkpointed or not + """ + return self._jrdd.rdd().isCheckpointed() + + def getCheckpointFile(self): + """ + Gets the name of the file to which this RDD was checkpointed + """ + checkpointFile = self._jrdd.rdd().getCheckpointFile() + if checkpointFile.isDefined(): + return checkpointFile.get() + else: + return None + # TODO persist(self, storageLevel) def map(self, f, preservesPartitioning=False): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py new file mode 100644 index 0000000000..c959d5dec7 --- /dev/null +++ b/python/pyspark/tests.py @@ -0,0 +1,46 @@ +""" +Unit tests for PySpark; additional tests are implemented as doctests in +individual modules. +""" +import atexit +import os +import shutil +from tempfile import NamedTemporaryFile +import time +import unittest + +from pyspark.context import SparkContext + + +class TestCheckpoint(unittest.TestCase): + + def setUp(self): + self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2) + + def tearDown(self): + self.sc.stop() + + def test_basic_checkpointing(self): + checkpointDir = NamedTemporaryFile(delete=False) + os.unlink(checkpointDir.name) + self.sc.setCheckpointDir(checkpointDir.name) + + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertIsNone(flatMappedRDD.getCheckpointFile()) + + flatMappedRDD.checkpoint() + result = flatMappedRDD.collect() + time.sleep(1) # 1 second + self.assertTrue(flatMappedRDD.isCheckpointed()) + self.assertEqual(flatMappedRDD.collect(), result) + self.assertEqual(checkpointDir.name, + os.path.dirname(flatMappedRDD.getCheckpointFile())) + + atexit.register(lambda: shutil.rmtree(checkpointDir.name)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/run-tests b/python/run-tests index 32470911f9..ce214e98a8 100755 --- a/python/run-tests +++ b/python/run-tests @@ -14,6 +14,9 @@ FAILED=$(($?||$FAILED)) $FWDIR/pyspark -m doctest pyspark/accumulators.py FAILED=$(($?||$FAILED)) +$FWDIR/pyspark -m unittest pyspark.tests +FAILED=$(($?||$FAILED)) + if [[ $FAILED != 0 ]]; then echo -en "\033[31m" # Red echo "Had test failures; see logs." -- cgit v1.2.3 From 5b6ea9e9a04994553d0319c541ca356e2e3064a7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 20 Jan 2013 15:31:41 -0800 Subject: Update checkpointing API docs in Python/Java. 
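
To make the documented workflow concrete, here is a minimal Scala sketch of how the checkpointing API described in these doc updates is meant to be used. It is not taken from the patches themselves; the master URL, application name, and checkpoint path are illustrative assumptions.

    import spark.SparkContext

    object CheckpointExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "CheckpointExample")
        // The checkpoint directory must be set before any RDD is marked for checkpointing.
        sc.setCheckpointDir("/tmp/spark-checkpoints", true)

        val doubled = sc.parallelize(1 to 5).map(_ * 2)
        doubled.cache()      // persist first so checkpointing does not force a full recomputation
        doubled.checkpoint() // must be called before any job has run on this RDD
        doubled.count()      // the first job on this RDD performs the actual checkpoint

        println(doubled.isCheckpointed)    // true once the job above completes
        println(doubled.getCheckpointFile) // Some(<path under the checkpoint directory>)
        sc.stop()
      }
    }
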
--- core/src/main/scala/spark/api/java/JavaRDDLike.scala | 17 +++++++---------- .../main/scala/spark/api/java/JavaSparkContext.scala | 17 +++++++++-------- python/pyspark/context.py | 11 +++++++---- python/pyspark/rdd.py | 17 +++++------------ 4 files changed, 28 insertions(+), 34 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 087270e46d..b3698ffa44 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -307,16 +307,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] JavaPairRDD.fromRDD(rdd.keyBy(f)) } - - /** - * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` - * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. - * This is used to truncate very long lineages. In the current implementation, Spark will save - * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. - * Hence, it is strongly recommended to use checkpoint() on RDDs when - * (i) checkpoint() is called before the any job has been executed on this RDD. - * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will - * require recomputation. + + /** + * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint + * directory set with SparkContext.setCheckpointDir() and all references to its parent + * RDDs will be removed. This function must be called before any job has been + * executed on this RDD. It is strongly recommended that this RDD is persisted in + * memory, otherwise saving it on a file will require recomputation. */ def checkpoint() = rdd.checkpoint() diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index fa2f14113d..14699961ad 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -357,20 +357,21 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork } /** - * Set the directory under which RDDs are going to be checkpointed. This method will - * create this directory and will throw an exception of the path already exists (to avoid - * overwriting existing files may be overwritten). The directory will be deleted on exit - * if indicated. + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists and useExisting is set to true, then the + * exisiting directory will be used. Otherwise an exception will be thrown to + * prevent accidental overriding of checkpoint files in the existing directory. */ def setCheckpointDir(dir: String, useExisting: Boolean) { sc.setCheckpointDir(dir, useExisting) } /** - * Set the directory under which RDDs are going to be checkpointed. This method will - * create this directory and will throw an exception of the path already exists (to avoid - * overwriting existing files may be overwritten). The directory will be deleted on exit - * if indicated. + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. 
If the directory does not exist, it will + * be created. If the directory exists, an exception will be thrown to prevent accidental + * overriding of checkpoint files. */ def setCheckpointDir(dir: String) { sc.setCheckpointDir(dir) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 8beb8e2ae9..dcbed37270 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -202,9 +202,12 @@ class SparkContext(object): def setCheckpointDir(self, dirName, useExisting=False): """ - Set the directory under which RDDs are going to be checkpointed. This - method will create this directory and will throw an exception of the - path already exists (to avoid overwriting existing files may be - overwritten). The directory will be deleted on exit if indicated. + Set the directory under which RDDs are going to be checkpointed. The + directory must be a HDFS path if running on a cluster. + + If the directory does not exist, it will be created. If the directory + exists and C{useExisting} is set to true, then the exisiting directory + will be used. Otherwise an exception will be thrown to prevent + accidental overriding of checkpoint files in the existing directory. """ self._jsc.sc().setCheckpointDir(dirName, useExisting) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 2a2ff9b271..7b6ab956ee 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -52,18 +52,11 @@ class RDD(object): def checkpoint(self): """ - Mark this RDD for checkpointing. The RDD will be saved to a file inside - `checkpointDir` (set using setCheckpointDir()) and all references to - its parent RDDs will be removed. This is used to truncate very long - lineages. In the current implementation, Spark will save this RDD to - a file (using saveAsObjectFile()) after the first job using this RDD is - done. Hence, it is strongly recommended to use checkpoint() on RDDs - when - - (i) checkpoint() is called before the any job has been executed on this - RDD. - - (ii) This RDD has been made to persist in memory. Otherwise saving it + Mark this RDD for checkpointing. It will be saved to a file inside the + checkpoint directory set with L{SparkContext.setCheckpointDir()} and + all references to its parent RDDs will be removed. This function must + be called before any job has been executed on this RDD. It is strongly + recommended that this RDD is persisted in memory, otherwise saving it on a file will require recomputation. """ self.is_checkpointed = True -- cgit v1.2.3 From 9f211dd3f0132daf72fb39883fa4b28e4fd547ca Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 14 Jan 2013 15:30:42 -0800 Subject: Fix PythonPartitioner equality; see SPARK-654. PythonPartitioner did not take the Python-side partitioning function into account when checking for equality, which might cause problems in the future. --- .../main/scala/spark/api/python/PythonPartitioner.scala | 13 +++++++++++-- core/src/main/scala/spark/api/python/PythonRDD.scala | 5 ----- python/pyspark/rdd.py | 17 +++++++++++------ 3 files changed, 22 insertions(+), 13 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala index 648d9402b0..519e310323 100644 --- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -6,8 +6,17 @@ import java.util.Arrays /** * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API. 
+ * + * Stores the unique id() of the Python-side partitioning function so that it is incorporated into + * equality comparisons. Correctness requires that the id is a unique identifier for the + * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning + * function). This can be ensured by using the Python id() function and maintaining a reference + * to the Python partitioning function so that its id() is not reused. */ -private[spark] class PythonPartitioner(override val numPartitions: Int) extends Partitioner { +private[spark] class PythonPartitioner( + override val numPartitions: Int, + val pyPartitionFunctionId: Long) + extends Partitioner { override def getPartition(key: Any): Int = { if (key == null) { @@ -32,7 +41,7 @@ private[spark] class PythonPartitioner(override val numPartitions: Int) extends override def equals(other: Any): Boolean = other match { case h: PythonPartitioner => - h.numPartitions == numPartitions + h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId case _ => false } diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 89f7c316dc..e4c0530241 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -252,11 +252,6 @@ private object Pickle { val APPENDS: Byte = 'e' } -private class ExtractValue extends spark.api.java.function.Function[(Array[Byte], - Array[Byte]), Array[Byte]] { - override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2 -} - private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d705f0f9e1..b58bf24e3e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -33,6 +33,7 @@ class RDD(object): self._jrdd = jrdd self.is_cached = False self.ctx = ctx + self._partitionFunc = None @property def context(self): @@ -497,7 +498,7 @@ class RDD(object): return python_right_outer_join(self, other, numSplits) # TODO: add option to control map-side combining - def partitionBy(self, numSplits, hashFunc=hash): + def partitionBy(self, numSplits, partitionFunc=hash): """ Return a copy of the RDD partitioned using the specified partitioner. 
@@ -514,17 +515,21 @@ class RDD(object): def add_shuffle_key(split, iterator): buckets = defaultdict(list) for (k, v) in iterator: - buckets[hashFunc(k) % numSplits].append((k, v)) + buckets[partitionFunc(k) % numSplits].append((k, v)) for (split, items) in buckets.iteritems(): yield str(split) yield dump_pickle(Batch(items)) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) - jrdd = pairRDD.partitionBy(partitioner) - jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) - return RDD(jrdd, self.ctx) + partitioner = self.ctx.jvm.PythonPartitioner(numSplits, + id(partitionFunc)) + jrdd = pairRDD.partitionBy(partitioner).values() + rdd = RDD(jrdd, self.ctx) + # This is required so that id(partitionFunc) remains unique, even if + # partitionFunc is a lambda: + rdd._partitionFunc = partitionFunc + return rdd # TODO: add control over map-side aggregation def combineByKey(self, createCombiner, mergeValue, mergeCombiners, -- cgit v1.2.3 From c0b9ceb8c3d56c6d6f6f6b5925c87abad06be646 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 00:23:53 -0800 Subject: Log remote lifecycle events in Akka for easier debugging --- core/src/main/scala/spark/util/AkkaUtils.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index e67cb0336d..fbd0ff46bf 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -32,6 +32,7 @@ private[spark] object AkkaUtils { akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] akka.actor.provider = "akka.remote.RemoteActorRefProvider" akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" + akka.remote.log-remote-lifecycle-events = on akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d akka.remote.netty.connection-timeout = %ds -- cgit v1.2.3 From 69a417858bf1627de5220d41afba64853d4bf64d Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 12:42:11 -0600 Subject: Also use hadoopConfiguration in newAPI methods. 
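
A short Scala sketch of what this default buys callers: settings placed on the context's shared Hadoop configuration are now picked up by the new-API input methods without passing a Configuration explicitly. The input path and the configuration key below are illustrative assumptions, and the sketch presumes sc.hadoopConfiguration is reachable from user code, as the new default parameter suggests.

    import spark.SparkContext
    import org.apache.hadoop.io.{LongWritable, Text}
    import org.apache.hadoop.mapreduce.lib.input.TextInputFormat

    object NewApiConfExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "NewApiConfExample")
        // Adjust the context's Hadoop configuration once; the new-API methods default to it.
        sc.hadoopConfiguration.set("io.file.buffer.size", "65536")

        // No Configuration argument needed any more -- sc.hadoopConfiguration is the default.
        val lines = sc.newAPIHadoopFile[LongWritable, Text, TextInputFormat]("/tmp/input.txt")
        println(lines.map(_._2.toString).count())
        sc.stop()
      }
    }
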
--- core/src/main/scala/spark/PairRDDFunctions.scala | 4 ++-- core/src/main/scala/spark/SparkContext.scala | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 51c15837c4..1c18736805 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -494,7 +494,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { - saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, new Configuration) + saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass) } /** @@ -506,7 +506,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration) { + conf: Configuration = self.context.hadoopConfiguration) { val job = new NewAPIHadoopJob(conf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index f6b98c41bc..303e5081a4 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -293,8 +293,7 @@ class SparkContext( path, fm.erasure.asInstanceOf[Class[F]], km.erasure.asInstanceOf[Class[K]], - vm.erasure.asInstanceOf[Class[V]], - new Configuration(hadoopConfiguration)) + vm.erasure.asInstanceOf[Class[V]]) } /** @@ -306,7 +305,7 @@ class SparkContext( fClass: Class[F], kClass: Class[K], vClass: Class[V], - conf: Configuration): RDD[(K, V)] = { + conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) val updatedConf = job.getConfiguration @@ -318,7 +317,7 @@ class SparkContext( * and extra configuration options to pass to the input format. 
*/ def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( - conf: Configuration, + conf: Configuration = hadoopConfiguration, fClass: Class[F], kClass: Class[K], vClass: Class[V]): RDD[(K, V)] = { -- cgit v1.2.3 From f116d6b5c6029c2f96160bd84829a6fe8b73cccf Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 18 Jan 2013 13:24:37 -0800 Subject: executor can use a different sparkHome from Worker --- core/src/main/scala/spark/deploy/DeployMessage.scala | 4 +++- core/src/main/scala/spark/deploy/JobDescription.scala | 5 ++++- core/src/main/scala/spark/deploy/client/TestClient.scala | 3 ++- core/src/main/scala/spark/deploy/master/Master.scala | 9 +++++---- core/src/main/scala/spark/deploy/worker/Worker.scala | 4 ++-- .../spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 3 ++- 6 files changed, 18 insertions(+), 10 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 457122745b..7ee3e63429 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -5,6 +5,7 @@ import spark.deploy.master.{WorkerInfo, JobInfo} import spark.deploy.worker.ExecutorRunner import scala.collection.immutable.List import scala.collection.mutable.HashMap +import java.io.File private[spark] sealed trait DeployMessage extends Serializable @@ -42,7 +43,8 @@ private[spark] case class LaunchExecutor( execId: Int, jobDesc: JobDescription, cores: Int, - memory: Int) + memory: Int, + sparkHome: File) extends DeployMessage diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala index 20879c5f11..7f8f9af417 100644 --- a/core/src/main/scala/spark/deploy/JobDescription.scala +++ b/core/src/main/scala/spark/deploy/JobDescription.scala @@ -1,10 +1,13 @@ package spark.deploy +import java.io.File + private[spark] class JobDescription( val name: String, val cores: Int, val memoryPerSlave: Int, - val command: Command) + val command: Command, + val sparkHome: File) extends Serializable { val user = System.getProperty("user.name", "") diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index 57a7e123b7..dc743b1fbf 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -3,6 +3,7 @@ package spark.deploy.client import spark.util.AkkaUtils import spark.{Logging, Utils} import spark.deploy.{Command, JobDescription} +import java.io.File private[spark] object TestClient { @@ -25,7 +26,7 @@ private[spark] object TestClient { val url = args(0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) val desc = new JobDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map())) + "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), new File("dummy-spark-home")) val listener = new TestListener val client = new Client(actorSystem, url, desc, listener) client.start() diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 6ecebe626a..f0bee67159 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -6,6 +6,7 @@ import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, Remote import 
java.text.SimpleDateFormat import java.util.Date +import java.io.File import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -173,7 +174,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor for (pos <- 0 until numUsable) { if (assigned(pos) > 0) { val exec = job.addExecutor(usableWorkers(pos), assigned(pos)) - launchExecutor(usableWorkers(pos), exec) + launchExecutor(usableWorkers(pos), exec, job.desc.sparkHome) job.state = JobState.RUNNING } } @@ -186,7 +187,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor val coresToUse = math.min(worker.coresFree, job.coresLeft) if (coresToUse > 0) { val exec = job.addExecutor(worker, coresToUse) - launchExecutor(worker, exec) + launchExecutor(worker, exec, job.desc.sparkHome) job.state = JobState.RUNNING } } @@ -195,10 +196,10 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } } - def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) { + def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: File) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory) + worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 7c9e588ea2..078b2d8037 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -119,10 +119,10 @@ private[spark] class Worker( logError("Worker registration failed: " + message) System.exit(1) - case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_) => + case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) => logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name)) val manager = new ExecutorRunner( - jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, sparkHome, workDir) + jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, execSparkHome_, workDir) executors(jobId + "/" + execId) = manager manager.start() coresUsed += cores_ diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index e2301347e5..0dcc2efaca 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -4,6 +4,7 @@ import spark.{Utils, Logging, SparkContext} import spark.deploy.client.{Client, ClientListener} import spark.deploy.{Command, JobDescription} import scala.collection.mutable.HashMap +import java.io.File private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, @@ -39,7 +40,7 @@ private[spark] class SparkDeploySchedulerBackend( StandaloneSchedulerBackend.ACTOR_NAME) val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.sparkHome)) client = new Client(sc.env.actorSystem, master, 
jobDesc, this) client.start() -- cgit v1.2.3 From aae5a920a4db0c31918a65a03ce7d2087826fd65 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 18 Jan 2013 13:28:50 -0800 Subject: get sparkHome the correct way --- .../scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 0dcc2efaca..08b9d6ff47 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -40,7 +40,7 @@ private[spark] class SparkDeploySchedulerBackend( StandaloneSchedulerBackend.ACTOR_NAME) val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.sparkHome)) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.getSparkHome())) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() -- cgit v1.2.3 From 5bf73df7f08b17719711a5f05f0b3390b4951272 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sat, 19 Jan 2013 13:26:15 -0800 Subject: oops, fix stupid compile error --- .../scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 08b9d6ff47..94886d3941 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -40,7 +40,8 @@ private[spark] class SparkDeploySchedulerBackend( StandaloneSchedulerBackend.ACTOR_NAME) val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.getSparkHome())) + val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sparkHome)) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() -- cgit v1.2.3 From c73107500e0a5b6c5f0b4aba8c4504ee4c2adbaf Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sun, 20 Jan 2013 21:55:50 -0800 Subject: send sparkHome as String instead of File over network --- core/src/main/scala/spark/deploy/DeployMessage.scala | 2 +- core/src/main/scala/spark/deploy/master/Master.scala | 2 +- core/src/main/scala/spark/deploy/worker/Worker.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 7ee3e63429..a4081ef89c 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -44,7 +44,7 @@ private[spark] case class LaunchExecutor( jobDesc: JobDescription, cores: Int, memory: Int, - sparkHome: File) + sparkHome: 
String) extends DeployMessage diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index f0bee67159..1b6f808a51 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -199,7 +199,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: File) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) + worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome.getAbsolutePath) exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 078b2d8037..19bf2be118 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -122,7 +122,7 @@ private[spark] class Worker( case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) => logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name)) val manager = new ExecutorRunner( - jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, execSparkHome_, workDir) + jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir) executors(jobId + "/" + execId) = manager manager.start() coresUsed += cores_ -- cgit v1.2.3 From fe26acc482f358bf87700f5e80160f7ce558cea7 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sun, 20 Jan 2013 21:57:44 -0800 Subject: remove unused imports --- core/src/main/scala/spark/deploy/DeployMessage.scala | 2 -- 1 file changed, 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index a4081ef89c..35f40c6e91 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -4,8 +4,6 @@ import spark.deploy.ExecutorState.ExecutorState import spark.deploy.master.{WorkerInfo, JobInfo} import spark.deploy.worker.ExecutorRunner import scala.collection.immutable.List -import scala.collection.mutable.HashMap -import java.io.File private[spark] sealed trait DeployMessage extends Serializable -- cgit v1.2.3 From a3f571b539ffd126e9f3bc3e9c7bedfcb6f4d2d4 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 21 Jan 2013 10:52:17 -0800 Subject: more File -> String changes --- core/src/main/scala/spark/deploy/JobDescription.scala | 4 +--- core/src/main/scala/spark/deploy/client/TestClient.scala | 3 +-- core/src/main/scala/spark/deploy/master/Master.scala | 5 ++--- .../scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 4 +--- 4 files changed, 5 insertions(+), 11 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala index 7f8f9af417..7160fc05fc 100644 --- a/core/src/main/scala/spark/deploy/JobDescription.scala +++ b/core/src/main/scala/spark/deploy/JobDescription.scala @@ -1,13 +1,11 @@ package spark.deploy -import java.io.File - private[spark] class JobDescription( val name: String, val cores: Int, val memoryPerSlave: Int, val command: Command, - 
val sparkHome: File) + val sparkHome: String) extends Serializable { val user = System.getProperty("user.name", "") diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index dc743b1fbf..8764c400e2 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -3,7 +3,6 @@ package spark.deploy.client import spark.util.AkkaUtils import spark.{Logging, Utils} import spark.deploy.{Command, JobDescription} -import java.io.File private[spark] object TestClient { @@ -26,7 +25,7 @@ private[spark] object TestClient { val url = args(0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) val desc = new JobDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), new File("dummy-spark-home")) + "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home") val listener = new TestListener val client = new Client(actorSystem, url, desc, listener) client.start() diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 1b6f808a51..2c2cd0231b 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -6,7 +6,6 @@ import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, Remote import java.text.SimpleDateFormat import java.util.Date -import java.io.File import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -196,10 +195,10 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } } - def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: File) { + def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome.getAbsolutePath) + worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) exec.job.actor ! 
ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 94886d3941..a21a5b2f3d 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -3,8 +3,6 @@ package spark.scheduler.cluster import spark.{Utils, Logging, SparkContext} import spark.deploy.client.{Client, ClientListener} import spark.deploy.{Command, JobDescription} -import scala.collection.mutable.HashMap -import java.io.File private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, @@ -41,7 +39,7 @@ private[spark] class SparkDeploySchedulerBackend( val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sparkHome)) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() -- cgit v1.2.3 From 4d34c7fc3ecd7a4d035005f84c01e6990c0c345e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 11:33:48 -0800 Subject: Fix compile error caused by cherry-pick --- .../main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index a21a5b2f3d..4f82cd96dd 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -3,6 +3,7 @@ package spark.scheduler.cluster import spark.{Utils, Logging, SparkContext} import spark.deploy.client.{Client, ClientListener} import spark.deploy.{Command, JobDescription} +import scala.collection.mutable.HashMap private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, -- cgit v1.2.3 From a88b44ed3b670633549049e9ccf990ea455e9720 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 11:59:21 -0800 Subject: Only bind to IPv4 addresses when trying to auto-detect external IP --- core/src/main/scala/spark/Utils.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index b3421df27c..692a3f4050 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -1,7 +1,7 @@ package spark import java.io._ -import java.net.{NetworkInterface, InetAddress, URL, URI} +import java.net.{NetworkInterface, InetAddress, Inet4Address, URL, URI} import java.util.{Locale, Random, UUID} import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} import org.apache.hadoop.conf.Configuration @@ -251,7 +251,8 @@ private object Utils extends Logging { // Address resolves to something like 127.0.1.1, which happens on Debian; try to find // a better address using the local network interfaces for (ni <- 
NetworkInterface.getNetworkInterfaces) { - for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && !addr.isLoopbackAddress) { + for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && + !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) { // We've found an address that looks reasonable! logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress + -- cgit v1.2.3 From ffd1623595cdce4080ad1e4e676e65898ebdd6dd Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 15:55:46 -0600 Subject: Minor cleanup. --- core/src/main/scala/spark/Accumulators.scala | 3 +-- core/src/main/scala/spark/Logging.scala | 3 +-- core/src/main/scala/spark/ParallelCollection.scala | 15 +++++---------- core/src/main/scala/spark/TaskContext.scala | 3 +-- core/src/main/scala/spark/rdd/BlockRDD.scala | 6 ++---- core/src/main/scala/spark/rdd/CartesianRDD.scala | 3 +-- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 6 ++---- core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 6 ++---- core/src/main/scala/spark/rdd/SampledRDD.scala | 5 ++--- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 3 +-- core/src/main/scala/spark/rdd/UnionRDD.scala | 3 +-- core/src/main/scala/spark/rdd/ZippedRDD.scala | 3 +-- .../scala/spark/scheduler/local/LocalScheduler.scala | 4 ++-- .../scheduler/mesos/CoarseMesosSchedulerBackend.scala | 16 ++++++---------- .../spark/scheduler/mesos/MesosSchedulerBackend.scala | 10 +++------- core/src/test/scala/spark/FileServerSuite.scala | 4 ++-- 16 files changed, 33 insertions(+), 60 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index b644aba5f8..57c6df35be 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -25,8 +25,7 @@ class Accumulable[R, T] ( extends Serializable { val id = Accumulators.newId - @transient - private var value_ = initialValue // Current value on master + @transient private var value_ = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers var deserialized = false diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala index 90bae26202..7c1c1bb144 100644 --- a/core/src/main/scala/spark/Logging.scala +++ b/core/src/main/scala/spark/Logging.scala @@ -11,8 +11,7 @@ import org.slf4j.LoggerFactory trait Logging { // Make the log field transient so that objects with Logging can // be serialized and used on another machine - @transient - private var log_ : Logger = null + @transient private var log_ : Logger = null // Method to get or create the logger for this object protected def log: Logger = { diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index ede933c9e9..ad23e5bec8 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -23,32 +23,28 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( } private[spark] class ParallelCollection[T: ClassManifest]( - @transient sc : SparkContext, + @transient sc: SparkContext, @transient data: Seq[T], numSlices: Int, - locationPrefs : Map[Int,Seq[String]]) + locationPrefs: Map[Int,Seq[String]]) extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it 
gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split // instead. // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal. - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val slices = ParallelCollection.slice(data, numSlices).toArray slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray } - override def getSplits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_ override def compute(s: Split, context: TaskContext) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator override def getPreferredLocations(s: Split): Seq[String] = { - locationPrefs.get(s.index) match { - case Some(s) => s - case _ => Nil - } + locationPrefs.get(s.index) getOrElse Nil } override def clearDependencies() { @@ -56,7 +52,6 @@ private[spark] class ParallelCollection[T: ClassManifest]( } } - private object ParallelCollection { /** * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala index d2746b26b3..eab85f85a2 100644 --- a/core/src/main/scala/spark/TaskContext.scala +++ b/core/src/main/scala/spark/TaskContext.scala @@ -5,8 +5,7 @@ import scala.collection.mutable.ArrayBuffer class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable { - @transient - val onCompleteCallbacks = new ArrayBuffer[() => Unit] + @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit] // Add a callback function to be executed on task completion. An example use // is for HadoopRDD to register a callback to close the input stream. diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index b1095a52b4..2c022f88e0 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -11,13 +11,11 @@ private[spark] class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) extends RDD[T](sc, Nil) { - @transient - var splits_ : Array[Split] = (0 until blockIds.size).map(i => { + @transient var splits_ : Array[Split] = (0 until blockIds.size).map(i => { new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] }).toArray - @transient - lazy val locations_ = { + @transient lazy val locations_ = { val blockManager = SparkEnv.get.blockManager /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ val locations = blockManager.getLocations(blockIds) diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 79e7c24e7c..453d410ad4 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -35,8 +35,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( val numSplitsInRdd2 = rdd2.splits.size - @transient - var splits_ = { + @transient var splits_ = { // create the cross product split val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) for (s1 <- rdd1.splits; s2 <- rdd2.splits) { diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 1d528be2aa..8fafd27bb6 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -45,8 +45,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) val 
aggr = new CoGroupAggregator - @transient - var deps_ = { + @transient var deps_ = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { if (rdd.partitioner == Some(part)) { @@ -63,8 +62,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) override def getDependencies = deps_ - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val array = new Array[Split](part.numPartitions) for (i <- 0 until array.size) { array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) => diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index bb22db073c..c3b155fcbd 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -37,11 +37,9 @@ class NewHadoopRDD[K, V]( formatter.format(new Date()) } - @transient - private val jobId = new JobID(jobtrackerId, id) + @transient private val jobId = new JobID(jobtrackerId, id) - @transient - private val splits_ : Array[Split] = { + @transient private val splits_ : Array[Split] = { val inputFormat = inputFormatClass.newInstance val jobContext = newJobContext(conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index 1bc9c96112..e24ad23b21 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -19,13 +19,12 @@ class SampledRDD[T: ClassManifest]( seed: Int) extends RDD[T](prev) { - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val rg = new Random(seed) firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } - override def getSplits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_ override def getPreferredLocations(split: Split) = firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 1b219473e0..28ff19876d 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -22,8 +22,7 @@ class ShuffledRDD[K, V]( override val partitioner = Some(part) - @transient - var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) + @transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) override def getSplits = splits_ diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 24a085df02..82f0a44ecd 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -28,8 +28,7 @@ class UnionRDD[T: ClassManifest]( @transient var rdds: Seq[RDD[T]]) extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) var pos = 0 for (rdd <- rdds; split <- rdd.splits) { diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala index 16e6cc0f1b..d950b06c85 100644 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -34,8 +34,7 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( // 
TODO: FIX THIS. - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { if (rdd1.splits.size != rdd2.splits.size) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index dff550036d..21d255debd 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -19,8 +19,8 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon extends TaskScheduler with Logging { - var attemptId = new AtomicInteger(0) - var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) + val attemptId = new AtomicInteger(0) + val threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) val env = SparkEnv.get var listener: TaskSchedulerListener = null diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index c45c7df69c..014906b028 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -64,13 +64,9 @@ private[spark] class CoarseMesosSchedulerBackend( val taskIdToSlaveId = new HashMap[Int, String] val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed - val sparkHome = sc.getSparkHome() match { - case Some(path) => - path - case None => - throw new SparkException("Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor") - } + val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( + "Spark home is not set; set it through the spark.home system " + + "property, the SPARK_HOME environment variable or the SparkContext constructor")) val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt @@ -184,7 +180,7 @@ private[spark] class CoarseMesosSchedulerBackend( } /** Helper function to pull out a resource from a Mesos Resources protobuf */ - def getResource(res: JList[Resource], name: String): Double = { + private def getResource(res: JList[Resource], name: String): Double = { for (r <- res if r.getName == name) { return r.getScalar.getValue } @@ -193,7 +189,7 @@ private[spark] class CoarseMesosSchedulerBackend( } /** Build a Mesos resource protobuf object */ - def createResource(resourceName: String, quantity: Double): Protos.Resource = { + private def createResource(resourceName: String, quantity: Double): Protos.Resource = { Resource.newBuilder() .setName(resourceName) .setType(Value.Type.SCALAR) @@ -202,7 +198,7 @@ private[spark] class CoarseMesosSchedulerBackend( } /** Check whether a Mesos task state represents a finished task */ - def isFinished(state: MesosTaskState) = { + private def isFinished(state: MesosTaskState) = { state == MesosTaskState.TASK_FINISHED || state == MesosTaskState.TASK_FAILED || state == MesosTaskState.TASK_KILLED || diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index 8c7a1dfbc0..2989e31f5e 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ 
b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -76,13 +76,9 @@ private[spark] class MesosSchedulerBackend( } def createExecutorInfo(): ExecutorInfo = { - val sparkHome = sc.getSparkHome() match { - case Some(path) => - path - case None => - throw new SparkException("Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor") - } + val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( + "Spark home is not set; set it through the spark.home system " + + "property, the SPARK_HOME environment variable or the SparkContext constructor")) val execScript = new File(sparkHome, "spark-executor").getCanonicalPath val environment = Environment.newBuilder() sc.executorEnvs.foreach { case (key, value) => diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala index b4283d9604..fe964bd893 100644 --- a/core/src/test/scala/spark/FileServerSuite.scala +++ b/core/src/test/scala/spark/FileServerSuite.scala @@ -9,8 +9,8 @@ import SparkContext._ class FileServerSuite extends FunSuite with BeforeAndAfter { @transient var sc: SparkContext = _ - @transient var tmpFile : File = _ - @transient var testJarFile : File = _ + @transient var tmpFile: File = _ + @transient var testJarFile: File = _ before { // Create a sample text file -- cgit v1.2.3 From ef711902c1f42db14c8ddd524195f0a9efb56e65 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Jan 2013 16:42:24 -0800 Subject: Don't download files to master's working directory. This should avoid exceptions caused by existing files with different contents. I also removed some unused code. --- core/src/main/scala/spark/HttpFileServer.scala | 8 ++--- core/src/main/scala/spark/SparkContext.scala | 7 ++-- core/src/main/scala/spark/SparkEnv.scala | 20 +++++++---- core/src/main/scala/spark/SparkFiles.java | 25 ++++++++++++++ core/src/main/scala/spark/Utils.scala | 16 +-------- .../scala/spark/api/java/JavaSparkContext.scala | 5 +-- .../main/scala/spark/api/python/PythonRDD.scala | 2 ++ .../scala/spark/deploy/worker/ExecutorRunner.scala | 5 --- core/src/main/scala/spark/executor/Executor.scala | 6 ++-- .../spark/scheduler/local/LocalScheduler.scala | 6 ++-- core/src/test/scala/spark/FileServerSuite.scala | 9 +++-- python/pyspark/__init__.py | 5 ++- python/pyspark/context.py | 40 +++++++++++++++++++--- python/pyspark/files.py | 24 +++++++++++++ python/pyspark/worker.py | 3 ++ python/run-tests | 3 ++ 16 files changed, 133 insertions(+), 51 deletions(-) create mode 100644 core/src/main/scala/spark/SparkFiles.java create mode 100644 python/pyspark/files.py (limited to 'core') diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala index 659d17718f..00901d95e2 100644 --- a/core/src/main/scala/spark/HttpFileServer.scala +++ b/core/src/main/scala/spark/HttpFileServer.scala @@ -1,9 +1,7 @@ package spark -import java.io.{File, PrintWriter} -import java.net.URL -import scala.collection.mutable.HashMap -import org.apache.hadoop.fs.FileUtil +import java.io.{File} +import com.google.common.io.Files private[spark] class HttpFileServer extends Logging { @@ -40,7 +38,7 @@ private[spark] class HttpFileServer extends Logging { } def addFileToDir(file: File, dir: File) : String = { - Utils.copyFile(file, new File(dir, file.getName)) + Files.copy(file, new File(dir, file.getName)) return dir + "/" + file.getName } diff --git 
a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 8b6f4b3b7d..2eeca66ed6 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -439,9 +439,10 @@ class SparkContext( def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) /** - * Add a file to be downloaded into the working directory of this Spark job on every node. + * Add a file to be downloaded with this Spark job on every node. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(path)` to find its download location. */ def addFile(path: String) { val uri = new URI(path) @@ -454,7 +455,7 @@ class SparkContext( // Fetch the file locally in case a job is executed locally. // Jobs that run through LocalScheduler will already fetch the required dependencies, // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here. - Utils.fetchFile(path, new File(".")) + Utils.fetchFile(path, new File(SparkFiles.getRootDirectory)) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) } diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 41441720a7..6b44e29f4c 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -28,14 +28,10 @@ class SparkEnv ( val broadcastManager: BroadcastManager, val blockManager: BlockManager, val connectionManager: ConnectionManager, - val httpFileServer: HttpFileServer + val httpFileServer: HttpFileServer, + val sparkFilesDir: String ) { - /** No-parameter constructor for unit tests. */ - def this() = { - this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null) - } - def stop() { httpFileServer.stop() mapOutputTracker.stop() @@ -112,6 +108,15 @@ object SparkEnv extends Logging { httpFileServer.initialize() System.setProperty("spark.fileserver.uri", httpFileServer.serverUri) + // Set the sparkFiles directory, used when downloading dependencies. In local mode, + // this is a temporary directory; in distributed mode, this is the executor's current working + // directory. + val sparkFilesDir: String = if (isMaster) { + Utils.createTempDir().getAbsolutePath + } else { + "." + } + // Warn about deprecated spark.cache.class property if (System.getProperty("spark.cache.class") != null) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + @@ -128,6 +133,7 @@ object SparkEnv extends Logging { broadcastManager, blockManager, connectionManager, - httpFileServer) + httpFileServer, + sparkFilesDir) } } diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java new file mode 100644 index 0000000000..b59d8ce93f --- /dev/null +++ b/core/src/main/scala/spark/SparkFiles.java @@ -0,0 +1,25 @@ +package spark; + +import java.io.File; + +/** + * Resolves paths to files added through `addFile(). + */ +public class SparkFiles { + + private SparkFiles() {} + + /** + * Get the absolute path of a file added through `addFile()`. + */ + public static String get(String filename) { + return new File(getRootDirectory(), filename).getAbsolutePath(); + } + + /** + * Get the root directory that contains files added through `addFile()`. 
+ */ + public static String getRootDirectory() { + return SparkEnv.get().sparkFilesDir(); + } +} \ No newline at end of file diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 692a3f4050..827c8bd81e 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -111,20 +111,6 @@ private object Utils extends Logging { } } - /** Copy a file on the local file system */ - def copyFile(source: File, dest: File) { - val in = new FileInputStream(source) - val out = new FileOutputStream(dest) - copyStream(in, out, true) - } - - /** Download a file from a given URL to the local filesystem */ - def downloadFile(url: URL, localPath: String) { - val in = url.openStream() - val out = new FileOutputStream(localPath) - Utils.copyStream(in, out, true) - } - /** * Download a file requested by the executor. Supports fetching the file in a variety of ways, * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. @@ -201,7 +187,7 @@ private object Utils extends Logging { Utils.execute(Seq("tar", "-xf", filename), targetDir) } // Make the file executable - That's necessary for scripts - FileUtil.chmod(filename, "a+x") + FileUtil.chmod(targetFile.getAbsolutePath, "a+x") } /** diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index 16c122c584..50b8970cd8 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -323,9 +323,10 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def getSparkHome(): Option[String] = sc.getSparkHome() /** - * Add a file to be downloaded into the working directory of this Spark job on every node. + * Add a file to be downloaded with this Spark job on every node. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(path)` to find its download location. 
*/ def addFile(path: String) { sc.addFile(path) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 5526406a20..f43a152ca7 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -67,6 +67,8 @@ private[spark] class PythonRDD[T: ClassManifest]( val dOut = new DataOutputStream(proc.getOutputStream) // Split index dOut.writeInt(split.index) + // sparkFilesDir + PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut) // Broadcast variables dOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index beceb55ecd..0d1fe2a6b4 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -106,11 +106,6 @@ private[spark] class ExecutorRunner( throw new IOException("Failed to create directory " + executorDir) } - // Download the files it depends on into it (disabled for now) - //for (url <- jobDesc.fileUrls) { - // fetchFile(url, executorDir) - //} - // Launch the process val command = buildCommandSeq() val builder = new ProcessBuilder(command: _*).directory(executorDir) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 2552958d27..70629f6003 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -162,16 +162,16 @@ private[spark] class Executor extends Logging { // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentFiles(name) = timestamp } for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentJars(name) = timestamp // Add it to our class loader val localName = name.split("/").last - val url = new File(".", localName).toURI.toURL + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL if (!urlClassLoader.getURLs.contains(url)) { logInfo("Adding " + url + " to class loader") urlClassLoader.addURL(url) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index dff550036d..4451d314e6 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -116,16 +116,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentFiles(name) = timestamp } for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new 
File(SparkFiles.getRootDirectory)) currentJars(name) = timestamp // Add it to our class loader val localName = name.split("/").last - val url = new File(".", localName).toURI.toURL + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL if (!classLoader.getURLs.contains(url)) { logInfo("Adding " + url + " to class loader") classLoader.addURL(url) diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala index b4283d9604..528c6b8424 100644 --- a/core/src/test/scala/spark/FileServerSuite.scala +++ b/core/src/test/scala/spark/FileServerSuite.scala @@ -40,7 +40,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter { sc.addFile(tmpFile.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { - val in = new BufferedReader(new FileReader("FileServerSuite.txt")) + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) val fileVal = in.readLine().toInt in.close() _ * fileVal + _ * fileVal @@ -54,7 +55,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter { sc.addFile((new File(tmpFile.toString)).toURL.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { - val in = new BufferedReader(new FileReader("FileServerSuite.txt")) + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) val fileVal = in.readLine().toInt in.close() _ * fileVal + _ * fileVal @@ -83,7 +85,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter { sc.addFile(tmpFile.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { - val in = new BufferedReader(new FileReader("FileServerSuite.txt")) + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) val fileVal = in.readLine().toInt in.close() _ * fileVal + _ * fileVal diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 00666bc0a3..3e8bca62f0 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -11,6 +11,8 @@ Public classes: A broadcast variable that gets reused across tasks. - L{Accumulator} An "add-only" shared variable that tasks can only add values to. + - L{SparkFiles} + Access files shipped with jobs. """ import sys import os @@ -19,6 +21,7 @@ sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.eg from pyspark.context import SparkContext from pyspark.rdd import RDD +from pyspark.files import SparkFiles -__all__ = ["SparkContext", "RDD"] +__all__ = ["SparkContext", "RDD", "SparkFiles"] diff --git a/python/pyspark/context.py b/python/pyspark/context.py index dcbed37270..ec0cc7c2f9 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1,5 +1,7 @@ import os import atexit +import shutil +import tempfile from tempfile import NamedTemporaryFile from pyspark import accumulators @@ -173,10 +175,26 @@ class SparkContext(object): def addFile(self, path): """ - Add a file to be downloaded into the working directory of this Spark - job on every node. The C{path} passed can be either a local file, - a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, - HTTPS or FTP URI. + Add a file to be downloaded with this Spark job on every node. 
+ The C{path} passed can be either a local file, a file in HDFS + (or other Hadoop-supported filesystems), or an HTTP, HTTPS or + FTP URI. + + To access the file in Spark jobs, use + L{SparkFiles.get(path)} to find its + download location. + + >>> from pyspark import SparkFiles + >>> path = os.path.join(tempdir, "test.txt") + >>> with open(path, "w") as testFile: + ... testFile.write("100") + >>> sc.addFile(path) + >>> def func(iterator): + ... with open(SparkFiles.get("test.txt")) as testFile: + ... fileVal = int(testFile.readline()) + ... return [x * 100 for x in iterator] + >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect() + [100, 200, 300, 400] """ self._jsc.sc().addFile(path) @@ -211,3 +229,17 @@ class SparkContext(object): accidental overriding of checkpoint files in the existing directory. """ self._jsc.sc().setCheckpointDir(dirName, useExisting) + + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['tempdir'] = tempfile.mkdtemp() + atexit.register(lambda: shutil.rmtree(globs['tempdir'])) + doctest.testmod(globs=globs) + globs['sc'].stop() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/files.py b/python/pyspark/files.py new file mode 100644 index 0000000000..de1334f046 --- /dev/null +++ b/python/pyspark/files.py @@ -0,0 +1,24 @@ +import os + + +class SparkFiles(object): + """ + Resolves paths to files added through + L{addFile()}. + + SparkFiles contains only classmethods; users should not create SparkFiles + instances. + """ + + _root_directory = None + + def __init__(self): + raise NotImplementedError("Do not construct SparkFiles objects") + + @classmethod + def get(cls, filename): + """ + Get the absolute path of a file added through C{addFile()}. 
+ """ + path = os.path.join(SparkFiles._root_directory, filename) + return os.path.abspath(path) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b2b9288089..e7bdb7682b 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -8,6 +8,7 @@ from base64 import standard_b64decode from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler +from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, read_with_length, write_int, \ read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file @@ -23,6 +24,8 @@ def load_obj(): def main(): split_index = read_int(sys.stdin) + spark_files_dir = load_pickle(read_with_length(sys.stdin)) + SparkFiles._root_directory = spark_files_dir num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): bid = read_long(sys.stdin) diff --git a/python/run-tests b/python/run-tests index ce214e98a8..a3a9ff5dcb 100755 --- a/python/run-tests +++ b/python/run-tests @@ -8,6 +8,9 @@ FAILED=0 $FWDIR/pyspark pyspark/rdd.py FAILED=$(($?||$FAILED)) +$FWDIR/pyspark pyspark/context.py +FAILED=$(($?||$FAILED)) + $FWDIR/pyspark -m doctest pyspark/broadcast.py FAILED=$(($?||$FAILED)) -- cgit v1.2.3 From 7b9e96c99206c0679d9925e0161fde738a5c7c3a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Jan 2013 16:45:00 -0800 Subject: Add synchronization to Executor.updateDependencies() (SPARK-662) --- core/src/main/scala/spark/executor/Executor.scala | 34 ++++++++++++----------- 1 file changed, 18 insertions(+), 16 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 70629f6003..28d9d40d43 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -159,22 +159,24 @@ private[spark] class Executor extends Logging { * SparkContext. Also adds any new JARs we fetched to the class loader. 
*/ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentFiles(name) = timestamp - } - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!urlClassLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - urlClassLoader.addURL(url) + synchronized { + // Fetch missing dependencies + for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentFiles(name) = timestamp + } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentJars(name) = timestamp + // Add it to our class loader + val localName = name.split("/").last + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL + if (!urlClassLoader.getURLs.contains(url)) { + logInfo("Adding " + url + " to class loader") + urlClassLoader.addURL(url) + } } } } -- cgit v1.2.3 From 2d8218b8717435a47d7cea399290b30bf5ef010b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 20:00:27 -0600 Subject: Remove unneeded/now-broken saveAsNewAPIHadoopFile overload. --- core/src/main/scala/spark/PairRDDFunctions.scala | 12 ------------ 1 file changed, 12 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 656b820b8a..53b051f1c5 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -485,18 +485,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) } - /** - * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` - * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. - */ - def saveAsNewAPIHadoopFile( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { - saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass) - } - /** * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. -- cgit v1.2.3 From a8baeb93272b03a98e44c7bf5c541611aec4a64b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 21:30:24 -0600 Subject: Further simplify getOrElse call. 
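For reference, the preferredLocations lookup in ParallelCollection has now passed through three equivalent forms; a side-by-side sketch, illustrative only and not part of the patch:

    // 1. Original: explicit pattern match on the Option
    locationPrefs.get(s.index) match {
      case Some(prefs) => prefs
      case _ => Nil
    }

    // 2. Previous commit: getOrElse on the Option returned by get
    locationPrefs.get(s.index) getOrElse Nil

    // 3. This commit: Map.getOrElse with a default value
    locationPrefs.getOrElse(s.index, Nil)

All three return the registered preference list for the split index, or Nil when none exists; the final form avoids handling the intermediate Option at all.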
--- core/src/main/scala/spark/ParallelCollection.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index ad23e5bec8..10adcd53ec 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -44,7 +44,7 @@ private[spark] class ParallelCollection[T: ClassManifest]( s.asInstanceOf[ParallelCollectionSplit[T]].iterator override def getPreferredLocations(s: Split): Seq[String] = { - locationPrefs.get(s.index) getOrElse Nil + locationPrefs.getOrElse(s.index, Nil) } override def clearDependencies() { -- cgit v1.2.3 From 551a47a620c7dc207e3530e54d794a3c3aa8e45e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Jan 2013 23:31:00 -0800 Subject: Refactor daemon thread pool creation. --- .../src/main/scala/spark/DaemonThreadFactory.scala | 18 ------------ core/src/main/scala/spark/Utils.scala | 33 +++++----------------- .../scala/spark/network/ConnectionManager.scala | 5 ++-- .../spark/scheduler/local/LocalScheduler.scala | 2 +- .../spark/streaming/dstream/RawInputDStream.scala | 5 ++-- 5 files changed, 13 insertions(+), 50 deletions(-) delete mode 100644 core/src/main/scala/spark/DaemonThreadFactory.scala (limited to 'core') diff --git a/core/src/main/scala/spark/DaemonThreadFactory.scala b/core/src/main/scala/spark/DaemonThreadFactory.scala deleted file mode 100644 index 56e59adeb7..0000000000 --- a/core/src/main/scala/spark/DaemonThreadFactory.scala +++ /dev/null @@ -1,18 +0,0 @@ -package spark - -import java.util.concurrent.ThreadFactory - -/** - * A ThreadFactory that creates daemon threads - */ -private object DaemonThreadFactory extends ThreadFactory { - override def newThread(r: Runnable): Thread = new DaemonThread(r) -} - -private class DaemonThread(r: Runnable = null) extends Thread { - override def run() { - if (r != null) { - r.run() - } - } -} \ No newline at end of file diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 692a3f4050..9b8636f6c8 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -10,6 +10,7 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.io.Source import com.google.common.io.Files +import com.google.common.util.concurrent.ThreadFactoryBuilder /** * Various utility methods used by Spark. @@ -287,29 +288,14 @@ private object Utils extends Logging { customHostname.getOrElse(InetAddress.getLocalHost.getHostName) } - /** - * Returns a standard ThreadFactory except all threads are daemons. - */ - private def newDaemonThreadFactory: ThreadFactory = { - new ThreadFactory { - def newThread(r: Runnable): Thread = { - var t = Executors.defaultThreadFactory.newThread (r) - t.setDaemon (true) - return t - } - } - } + private[spark] val daemonThreadFactory: ThreadFactory = + new ThreadFactoryBuilder().setDaemon(true).build() /** * Wrapper over newCachedThreadPool. */ - def newDaemonCachedThreadPool(): ThreadPoolExecutor = { - var threadPool = Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory (newDaemonThreadFactory) - - return threadPool - } + def newDaemonCachedThreadPool(): ThreadPoolExecutor = + Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] /** * Return the string to tell how long has passed in seconds. 
The passing parameter should be in @@ -322,13 +308,8 @@ private object Utils extends Logging { /** * Wrapper over newFixedThreadPool. */ - def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { - var threadPool = Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory(newDaemonThreadFactory) - - return threadPool - } + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = + Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] /** * Delete a file or directory and its contents recursively. diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 36c01ad629..2ecd14f536 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -52,9 +52,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] val sendMessageRequests = new Queue[(Message, SendingConnection)] - implicit val futureExecContext = ExecutionContext.fromExecutor( - Executors.newCachedThreadPool(DaemonThreadFactory)) - + implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool()) + var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null serverChannel.configureBlocking(false) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index dff550036d..87f8474ea0 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -20,7 +20,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon with Logging { var attemptId = new AtomicInteger(0) - var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) + var threadPool = Utils.newDaemonFixedThreadPool(threads) val env = SparkEnv.get var listener: TaskSchedulerListener = null diff --git a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala index 290fab1ce0..04e6b69b7b 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala @@ -1,6 +1,6 @@ package spark.streaming.dstream -import spark.{DaemonThread, Logging} +import spark.Logging import spark.storage.StorageLevel import spark.streaming.StreamingContext @@ -48,7 +48,8 @@ class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel) val queue = new ArrayBlockingQueue[ByteBuffer](2) - blockPushingThread = new DaemonThread { + blockPushingThread = new Thread { + setDaemon(true) override def run() { var nextBlockNumber = 0 while (true) { -- cgit v1.2.3 From e353886a8ca6179f25b4176d7a62b5d04ce79276 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 22 Jan 2013 00:23:31 -0800 Subject: Use generation numbers for fetch failure tracking --- .../main/scala/spark/scheduler/DAGScheduler.scala | 27 ++++++++++++++++------ 1 file changed, 20 insertions(+), 7 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 59f2099e91..39a1e6d6c6 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala 
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -72,8 +72,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val cacheTracker = env.cacheTracker val mapOutputTracker = env.mapOutputTracker - val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back; - // that's not going to be a realistic assumption in general + // For tracking failed nodes, we use the MapOutputTracker's generation number, which is + // sent with every task. When we detect a node failing, we note the current generation number + // and failed host, increment it for new tasks, and use this to ignore stray ShuffleMapTask + // results. + // TODO: Garbage collect information about failure generations when new stages start. + val failedGeneration = new HashMap[String, Long] val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val running = new HashSet[Stage] // Stages we are running right now @@ -429,7 +433,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val status = event.result.asInstanceOf[MapStatus] val host = status.address.ip logInfo("ShuffleMapTask finished with host " + host) - if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos + if (failedGeneration.contains(host) && smt.generation <= failedGeneration(host)) { + logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + host) + } else { stage.addOutputLoc(smt.partition, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { @@ -495,7 +501,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock // TODO: mark the host as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleHostLost(bmAddress.ip) + handleHostLost(bmAddress.ip, Some(task.generation)) } case other => @@ -507,11 +513,15 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with /** * Responds to a host being lost. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside. + * + * Optionally the generation during which the failure was caught can be passed to avoid allowing + * stray fetch failures from possibly retriggering the detection of a node as lost. 
*/ - def handleHostLost(host: String) { - if (!deadHosts.contains(host)) { + def handleHostLost(host: String, maybeGeneration: Option[Long] = None) { + val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration) + if (!failedGeneration.contains(host) || failedGeneration(host) < currentGeneration) { + failedGeneration(host) = currentGeneration logInfo("Host lost: " + host) - deadHosts += host env.blockManager.master.notifyADeadHost(host) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { @@ -519,6 +529,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray mapOutputTracker.registerMapOutputs(shuffleId, locs, true) } + if (shuffleToMapStage.isEmpty) { + mapOutputTracker.incrementGeneration() + } cacheTracker.cacheLost(host) updateCacheLocs() } -- cgit v1.2.3 From 7e9ee2e8335f085062d3fdeecd0b49ec63e92117 Mon Sep 17 00:00:00 2001 From: Leemoonsoo Date: Tue, 22 Jan 2013 23:08:34 +0900 Subject: Fix for hanging spark.HttpFileServer with kind of virtual network --- core/src/main/scala/spark/HttpServer.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/HttpServer.scala b/core/src/main/scala/spark/HttpServer.scala index 0196595ba1..4e0507c080 100644 --- a/core/src/main/scala/spark/HttpServer.scala +++ b/core/src/main/scala/spark/HttpServer.scala @@ -4,6 +4,7 @@ import java.io.File import java.net.InetAddress import org.eclipse.jetty.server.Server +import org.eclipse.jetty.server.bio.SocketConnector import org.eclipse.jetty.server.handler.DefaultHandler import org.eclipse.jetty.server.handler.HandlerList import org.eclipse.jetty.server.handler.ResourceHandler @@ -27,7 +28,13 @@ private[spark] class HttpServer(resourceBase: File) extends Logging { if (server != null) { throw new ServerStateException("Server is already started") } else { - server = new Server(0) + server = new Server() + val connector = new SocketConnector + connector.setMaxIdleTime(60*1000) + connector.setSoLingerTime(-1) + connector.setPort(0) + server.addConnector(connector) + val threadPool = new QueuedThreadPool threadPool.setDaemon(true) server.setThreadPool(threadPool) -- cgit v1.2.3 From 588b24197a85c4b46a38595007293abef9a41f2c Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 10:19:30 -0600 Subject: Use default arguments instead of constructor overloads. --- core/src/main/scala/spark/SparkContext.scala | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 8b6f4b3b7d..495d1b6c78 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -58,27 +58,11 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend class SparkContext( val master: String, val jobName: String, - val sparkHome: String, - val jars: Seq[String], - environment: Map[String, String]) + val sparkHome: String = null, + val jars: Seq[String] = Nil, + environment: Map[String, String] = Map()) extends Logging { - /** - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). 
- * @param jobName A name for your job, to display on the cluster web UI - * @param sparkHome Location where Spark is installed on cluster nodes. - * @param jars Collection of JARs to send to the cluster. These can be paths on the local file - * system or HDFS, HTTP, HTTPS, or FTP URLs. - */ - def this(master: String, jobName: String, sparkHome: String, jars: Seq[String]) = - this(master, jobName, sparkHome, jars, Map()) - - /** - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param jobName A name for your job, to display on the cluster web UI - */ - def this(master: String, jobName: String) = this(master, jobName, null, Nil, Map()) - // Ensure logging is initialized before we spawn any threads initLogging() -- cgit v1.2.3 From 50e2b23927956c14db40093d31bc80892764006a Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 22 Jan 2013 09:27:33 -0800 Subject: Fix up some problems from the merge --- .../main/scala/spark/storage/BlockManagerMasterActor.scala | 11 +++++++++++ core/src/main/scala/spark/storage/BlockManagerMessages.scala | 3 +++ core/src/main/scala/spark/storage/StorageUtils.scala | 8 ++++---- 3 files changed, 18 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index f4d026da33..c945c34c71 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -68,6 +68,9 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { case GetMemoryStatus => getMemoryStatus + case GetStorageStatus => + getStorageStatus + case RemoveBlock(blockId) => removeBlock(blockId) @@ -177,6 +180,14 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { sender ! res } + private def getStorageStatus() { + val res = blockManagerInfo.map { case(blockManagerId, info) => + import collection.JavaConverters._ + StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap) + } + sender ! 
res + } + private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { val startTimeMs = System.currentTimeMillis() val tmp = " " + blockManagerId + " " diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index d73a9b790f..3a381fd385 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -100,3 +100,6 @@ case object GetMemoryStatus extends ToBlockManagerMaster private[spark] case object ExpireDeadHosts extends ToBlockManagerMaster + +private[spark] +case object GetStorageStatus extends ToBlockManagerMaster \ No newline at end of file diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index ebc7390ee5..63ad5c125b 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -1,6 +1,7 @@ package spark.storage import spark.SparkContext +import BlockManagerMasterActor.BlockStatus private[spark] case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, @@ -20,8 +21,8 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, } -case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, - numPartitions: Int, memSize: Long, diskSize: Long, locations: Array[BlockManagerId]) +case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, + numPartitions: Int, memSize: Long, diskSize: Long) /* Helper methods for storage-related objects */ @@ -58,8 +59,7 @@ object StorageUtils { val rddName = Option(sc.persistentRdds.get(rddId).name).getOrElse(rddKey) val rddStorageLevel = sc.persistentRdds.get(rddId).getStorageLevel - RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize, - rddBlocks.map(_.blockManagerId)) + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize) }.toArray } -- cgit v1.2.3 From 27b3f3f0a980f86bac14a14516b5d52a32aa8cbb Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 15:30:42 -0600 Subject: Handle slaveLost before slaveIdToHost knows about it. --- .../spark/scheduler/cluster/ClusterScheduler.scala | 31 +++++++++++++--------- 1 file changed, 18 insertions(+), 13 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 20f6e65020..a639b72795 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -252,19 +252,24 @@ private[spark] class ClusterScheduler(val sc: SparkContext) def slaveLost(slaveId: String, reason: ExecutorLossReason) { var failedHost: Option[String] = None synchronized { - val host = slaveIdToHost(slaveId) - if (hostsAlive.contains(host)) { - logError("Lost an executor on " + host + ": " + reason) - slaveIdsWithExecutors -= slaveId - hostsAlive -= host - activeTaskSetsQueue.foreach(_.hostLost(host)) - failedHost = Some(host) - } else { - // We may get multiple slaveLost() calls with different loss reasons. For example, one - // may be triggered by a dropped connection from the slave while another may be a report - // of executor termination from Mesos. We produce log messages for both so we eventually - // report the termination reason. 
- logError("Lost an executor on " + host + " (already removed): " + reason) + slaveIdToHost.get(slaveId) match { + case Some(host) => + if (hostsAlive.contains(host)) { + logError("Lost an executor on " + host + ": " + reason) + slaveIdsWithExecutors -= slaveId + hostsAlive -= host + activeTaskSetsQueue.foreach(_.hostLost(host)) + failedHost = Some(host) + } else { + // We may get multiple slaveLost() calls with different loss reasons. For example, one + // may be triggered by a dropped connection from the slave while another may be a report + // of executor termination from Mesos. We produce log messages for both so we eventually + // report the termination reason. + logError("Lost an executor on " + host + " (already removed): " + reason) + } + case None => + // We were told about a slave being lost before we could even allocate work to it + logError("Lost slave " + slaveId + " (no work assigned yet)") } } if (failedHost != None) { -- cgit v1.2.3 From 6f2194f7576eb188c23f18125f5101ae0b4e9e4d Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 15:38:58 -0600 Subject: Call removeJob instead of killing the cluster. --- core/src/main/scala/spark/deploy/master/Master.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 2c2cd0231b..d1a65204b8 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -103,8 +103,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor val e = new SparkException("Job %s wth ID %s failed %d times.".format( jobInfo.desc.name, jobInfo.id, jobInfo.retryCount)) logError(e.getMessage, e) - throw e - //System.exit(1) + removeJob(jobInfo) } } } -- cgit v1.2.3 From 250fe89679bb59ef0d31f74985f72556dcfe2d06 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 16:29:05 -0600 Subject: Handle Master telling the Worker to kill an already-dead executor. --- core/src/main/scala/spark/deploy/worker/Worker.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 19bf2be118..d040b86908 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -143,9 +143,13 @@ private[spark] class Worker( case KillExecutor(jobId, execId) => val fullId = jobId + "/" + execId - val executor = executors(fullId) - logInfo("Asked to kill executor " + fullId) - executor.kill() + executors.get(fullId) match { + case Some(executor) => + logInfo("Asked to kill executor " + fullId) + executor.kill() + case None => + logInfo("Asked to kill non-existent existent " + fullId) + } case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => masterDisconnected() -- cgit v1.2.3 From 8c51322cd05f2ae97a08c3af314c7608fcf71b57 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 18:09:10 -0600 Subject: Don't bother creating an exception. 
--- core/src/main/scala/spark/deploy/master/Master.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index d1a65204b8..361e5ac627 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -100,9 +100,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor if (jobInfo.incrementRetryCount <= JobState.MAX_NUM_RETRY) { schedule() } else { - val e = new SparkException("Job %s wth ID %s failed %d times.".format( + logError("Job %s wth ID %s failed %d times, removing it".format( jobInfo.desc.name, jobInfo.id, jobInfo.retryCount)) - logError(e.getMessage, e) removeJob(jobInfo) } } -- cgit v1.2.3 From 98d0b7747d7539db009a9bbc261f899955871524 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 18:11:51 -0600 Subject: Fix Worker logInfo about unknown executor. --- core/src/main/scala/spark/deploy/worker/Worker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index d040b86908..5a83a42daf 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -148,7 +148,7 @@ private[spark] class Worker( logInfo("Asked to kill executor " + fullId) executor.kill() case None => - logInfo("Asked to kill non-existent existent " + fullId) + logInfo("Asked to kill unknown executor " + fullId) } case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => -- cgit v1.2.3 From 284993100022cc4bd43bf84a0be4dd91cf7a4ac0 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 22 Jan 2013 22:19:30 -0800 Subject: Eliminate CacheTracker. Replaces DAGScheduler's queries of CacheTracker with BlockManagerMaster queries. Adds CacheManager to locally coordinate computation of cached RDDs. 
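The hunks below delete CacheTracker (its actor and test suite included) and switch RDD.iterator over to a cacheManager, while the DAGScheduler now asks the BlockManagerMaster for block locations directly. The new CacheManager file itself is not among these hunks, so the following is only an assumed sketch, reconstructed from the getOrCompute logic being removed from CacheTracker further down: a worker-local "loading" set makes sure one thread per node computes a given cached partition while others wait, and everything else goes through the BlockManager.

    // Assumed shape of the new CacheManager (not part of the hunks below); the body mirrors the
    // getOrCompute logic deleted from CacheTracker, minus all the tracker-actor plumbing.
    package spark

    import scala.collection.mutable.{ArrayBuffer, HashSet}
    import spark.storage.{BlockManager, StorageLevel}

    private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
      private val loading = new HashSet[String]   // splits currently being computed on this node

      def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext,
                          storageLevel: StorageLevel): Iterator[T] = {
        val key = "rdd_%d_%d".format(rdd.id, split.index)
        blockManager.get(key) match {
          case Some(cached) => return cached.asInstanceOf[Iterator[T]]
          case None => // not cached yet, fall through and compute it
        }
        loading.synchronized {
          while (loading.contains(key)) {
            try { loading.wait() } catch { case _: InterruptedException => }  // another thread is on it
          }
          blockManager.get(key) match {
            case Some(cached) => return cached.asInstanceOf[Iterator[T]]      // it finished, reuse the result
            case None => loading.add(key)                                     // we compute it ourselves
          }
        }
        try {
          val elements = new ArrayBuffer[Any]
          elements ++= rdd.compute(split, context)
          blockManager.put(key, elements, storageLevel, true)  // tellMaster = true keeps locations queryable
          elements.iterator.asInstanceOf[Iterator[T]]
        } finally {
          loading.synchronized { loading.remove(key); loading.notifyAll() }
        }
      }
    }
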
--- core/src/main/scala/spark/CacheTracker.scala | 240 --------------------- core/src/main/scala/spark/RDD.scala | 2 +- core/src/main/scala/spark/SparkEnv.scala | 8 +- .../main/scala/spark/scheduler/DAGScheduler.scala | 24 ++- .../main/scala/spark/storage/BlockManager.scala | 24 +-- core/src/test/scala/spark/CacheTrackerSuite.scala | 131 ----------- 6 files changed, 18 insertions(+), 411 deletions(-) delete mode 100644 core/src/main/scala/spark/CacheTracker.scala delete mode 100644 core/src/test/scala/spark/CacheTrackerSuite.scala (limited to 'core') diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala deleted file mode 100644 index 86ad737583..0000000000 --- a/core/src/main/scala/spark/CacheTracker.scala +++ /dev/null @@ -1,240 +0,0 @@ -package spark - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -import akka.actor._ -import akka.dispatch._ -import akka.pattern.ask -import akka.remote._ -import akka.util.Duration -import akka.util.Timeout -import akka.util.duration._ - -import spark.storage.BlockManager -import spark.storage.StorageLevel -import util.{TimeStampedHashSet, MetadataCleaner, TimeStampedHashMap} - -private[spark] sealed trait CacheTrackerMessage - -private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L) - extends CacheTrackerMessage -private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L) - extends CacheTrackerMessage -private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage -private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage -private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage -private[spark] case object GetCacheStatus extends CacheTrackerMessage -private[spark] case object GetCacheLocations extends CacheTrackerMessage -private[spark] case object StopCacheTracker extends CacheTrackerMessage - -private[spark] class CacheTrackerActor extends Actor with Logging { - // TODO: Should probably store (String, CacheType) tuples - private val locs = new TimeStampedHashMap[Int, Array[List[String]]] - - /** - * A map from the slave's host name to its cache size. - */ - private val slaveCapacity = new HashMap[String, Long] - private val slaveUsage = new HashMap[String, Long] - - private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.clearOldValues) - - private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) - private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) - private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host) - - def receive = { - case SlaveCacheStarted(host: String, size: Long) => - slaveCapacity.put(host, size) - slaveUsage.put(host, 0) - sender ! true - - case RegisterRDD(rddId: Int, numPartitions: Int) => - logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions") - locs(rddId) = Array.fill[List[String]](numPartitions)(Nil) - sender ! true - - case AddedToCache(rddId, partition, host, size) => - slaveUsage.put(host, getCacheUsage(host) + size) - locs(rddId)(partition) = host :: locs(rddId)(partition) - sender ! true - - case DroppedFromCache(rddId, partition, host, size) => - slaveUsage.put(host, getCacheUsage(host) - size) - // Do a sanity check to make sure usage is greater than 0. 
- locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host) - sender ! true - - case MemoryCacheLost(host) => - logInfo("Memory cache lost on " + host) - for ((id, locations) <- locs) { - for (i <- 0 until locations.length) { - locations(i) = locations(i).filterNot(_ == host) - } - } - sender ! true - - case GetCacheLocations => - logInfo("Asked for current cache locations") - sender ! locs.map{case (rrdId, array) => (rrdId -> array.clone())} - - case GetCacheStatus => - val status = slaveCapacity.map { case (host, capacity) => - (host, capacity, getCacheUsage(host)) - }.toSeq - sender ! status - - case StopCacheTracker => - logInfo("Stopping CacheTrackerActor") - sender ! true - metadataCleaner.cancel() - context.stop(self) - } -} - -private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager) - extends Logging { - - // Tracker actor on the master, or remote reference to it on workers - val ip: String = System.getProperty("spark.master.host", "localhost") - val port: Int = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "CacheTracker" - - val timeout = 10.seconds - - var trackerActor: ActorRef = if (isMaster) { - val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName) - logInfo("Registered CacheTrackerActor actor") - actor - } else { - val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) - actorSystem.actorFor(url) - } - - // TODO: Consider removing this HashSet completely as locs CacheTrackerActor already - // keeps track of registered RDDs - val registeredRddIds = new TimeStampedHashSet[Int] - - // Remembers which splits are currently being loaded (on worker nodes) - val loading = new HashSet[String] - - val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.clearOldValues) - - // Send a message to the trackerActor and get its result within a default timeout, or - // throw a SparkException if this fails. - def askTracker(message: Any): Any = { - try { - val future = trackerActor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with CacheTracker", e) - } - } - - // Send a one-way message to the trackerActor, to which we expect it to reply with true. - def communicate(message: Any) { - if (askTracker(message) != true) { - throw new SparkException("Error reply received from CacheTracker") - } - } - - // Registers an RDD (on master only) - def registerRDD(rddId: Int, numPartitions: Int) { - registeredRddIds.synchronized { - if (!registeredRddIds.contains(rddId)) { - logInfo("Registering RDD ID " + rddId + " with cache") - registeredRddIds += rddId - communicate(RegisterRDD(rddId, numPartitions)) - } - } - } - - // For BlockManager.scala only - def cacheLost(host: String) { - communicate(MemoryCacheLost(host)) - logInfo("CacheTracker successfully removed entries on " + host) - } - - // Get the usage status of slave caches. Each tuple in the returned sequence - // is in the form of (host name, capacity, usage). 
- def getCacheStatus(): Seq[(String, Long, Long)] = { - askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]] - } - - // For BlockManager.scala only - def notifyFromBlockManager(t: AddedToCache) { - communicate(t) - } - - // Get a snapshot of the currently known locations - def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = { - askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]] - } - - // Gets or computes an RDD split - def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) - : Iterator[T] = { - val key = "rdd_%d_%d".format(rdd.id, split.index) - logInfo("Cache key is " + key) - blockManager.get(key) match { - case Some(cachedValues) => - // Split is in cache, so just return its values - logInfo("Found partition in cache!") - return cachedValues.asInstanceOf[Iterator[T]] - - case None => - // Mark the split as loading (unless someone else marks it first) - loading.synchronized { - if (loading.contains(key)) { - logInfo("Loading contains " + key + ", waiting...") - while (loading.contains(key)) { - try {loading.wait()} catch {case _ =>} - } - logInfo("Loading no longer contains " + key + ", so returning cached result") - // See whether someone else has successfully loaded it. The main way this would fail - // is for the RDD-level cache eviction policy if someone else has loaded the same RDD - // partition but we didn't want to make space for it. However, that case is unlikely - // because it's unlikely that two threads would work on the same RDD partition. One - // downside of the current code is that threads wait serially if this does happen. - blockManager.get(key) match { - case Some(values) => - return values.asInstanceOf[Iterator[T]] - case None => - logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") - loading.add(key) - } - } else { - loading.add(key) - } - } - try { - // If we got here, we have to load the split - val elements = new ArrayBuffer[Any] - logInfo("Computing partition " + split) - elements ++= rdd.compute(split, context) - // Try to put this block in the blockManager - blockManager.put(key, elements, storageLevel, true) - return elements.iterator.asInstanceOf[Iterator[T]] - } finally { - loading.synchronized { - loading.remove(key) - loading.notifyAll() - } - } - } - } - - // Called by the Cache to report that an entry has been dropped from it - def dropEntry(rddId: Int, partition: Int) { - communicate(DroppedFromCache(rddId, partition, Utils.localHostName())) - } - - def stop() { - communicate(StopCacheTracker) - registeredRddIds.clear() - trackerActor = null - } -} diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e0d2eabb1d..c79f34342f 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -176,7 +176,7 @@ abstract class RDD[T: ClassManifest]( if (isCheckpointed) { checkpointData.get.iterator(split, context) } else if (storageLevel != StorageLevel.NONE) { - SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel) + SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) } else { compute(split, context) } diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 41441720a7..a080194980 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -22,7 +22,7 @@ class SparkEnv ( val actorSystem: ActorSystem, val serializer: Serializer, val 
closureSerializer: Serializer, - val cacheTracker: CacheTracker, + val cacheManager: CacheManager, val mapOutputTracker: MapOutputTracker, val shuffleFetcher: ShuffleFetcher, val broadcastManager: BroadcastManager, @@ -39,7 +39,6 @@ class SparkEnv ( def stop() { httpFileServer.stop() mapOutputTracker.stop() - cacheTracker.stop() shuffleFetcher.stop() broadcastManager.stop() blockManager.stop() @@ -100,8 +99,7 @@ object SparkEnv extends Logging { val closureSerializer = instantiateClass[Serializer]( "spark.closure.serializer", "spark.JavaSerializer") - val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager) - blockManager.cacheTracker = cacheTracker + val cacheManager = new CacheManager(blockManager) val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster) @@ -122,7 +120,7 @@ object SparkEnv extends Logging { actorSystem, serializer, closureSerializer, - cacheTracker, + cacheManager, mapOutputTracker, shuffleFetcher, broadcastManager, diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 59f2099e91..03d173ac3b 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -69,8 +69,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with var cacheLocs = new HashMap[Int, Array[List[String]]] val env = SparkEnv.get - val cacheTracker = env.cacheTracker val mapOutputTracker = env.mapOutputTracker + val blockManagerMaster = env.blockManager.master val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back; // that's not going to be a realistic assumption in general @@ -95,11 +95,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with }.start() def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { + if (!cacheLocs.contains(rdd.id)) { + val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray + cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map { + locations => locations.map(_.ip).toList + }.toArray + } cacheLocs(rdd.id) } - def updateCacheLocs() { - cacheLocs = cacheTracker.getLocationsSnapshot() + def clearCacheLocs() { + cacheLocs.clear } /** @@ -126,7 +132,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of splits is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") - cacheTracker.registerRDD(rdd.id, rdd.splits.size) if (shuffleDep != None) { mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) } @@ -148,8 +153,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with visited += r // Kind of ugly: need to register RDDs with the cache here since // we can't do it in its constructor because # of splits is unknown - logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")") - cacheTracker.registerRDD(r.id, r.splits.size) for (dep <- r.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => @@ -250,7 +253,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val runId = nextRunId.getAndIncrement() val finalStage = newStage(finalRDD, None, runId) val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener) - updateCacheLocs() + clearCacheLocs() logInfo("Got job " + job.runId + " 
(" + callSite + ") with " + partitions.length + " output partitions") logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") @@ -293,7 +296,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // on the failed node. if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { logInfo("Resubmitting failed stages") - updateCacheLocs() + clearCacheLocs() val failed2 = failed.toArray failed.clear() for (stage <- failed2.sortBy(_.priority)) { @@ -443,7 +446,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with stage.shuffleDep.get.shuffleId, stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray) } - updateCacheLocs() + clearCacheLocs() if (stage.outputLocs.count(_ == Nil) != 0) { // Some tasks had failed; let's resubmit this stage // TODO: Lower-level scheduler should also deal with this @@ -519,8 +522,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray mapOutputTracker.registerMapOutputs(shuffleId, locs, true) } - cacheTracker.cacheLost(host) - updateCacheLocs() + clearCacheLocs() } } diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 7a8ac10cdd..e049565f48 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -16,7 +16,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils} +import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils} import spark.network._ import spark.serializer.Serializer import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap} @@ -71,9 +71,6 @@ class BlockManager( val connectionManagerId = connectionManager.id val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port) - // TODO: This will be removed after cacheTracker is removed from the code base. - var cacheTracker: CacheTracker = null - // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) val maxBytesInFlight = @@ -662,10 +659,6 @@ class BlockManager( BlockManager.dispose(bytesAfterPut) - // TODO: This code will be removed when CacheTracker is gone. - if (blockId.startsWith("rdd")) { - notifyCacheTracker(blockId) - } logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)) return size @@ -733,11 +726,6 @@ class BlockManager( } } - // TODO: This code will be removed when CacheTracker is gone. - if (blockId.startsWith("rdd")) { - notifyCacheTracker(blockId) - } - // If replication had started, then wait for it to finish if (level.replication > 1) { if (replicationFuture == null) { @@ -780,16 +768,6 @@ class BlockManager( } } - // TODO: This code will be removed when CacheTracker is gone. - private def notifyCacheTracker(key: String) { - if (cacheTracker != null) { - val rddInfo = key.split("_") - val rddId: Int = rddInfo(1).toInt - val partition: Int = rddInfo(2).toInt - cacheTracker.notifyFromBlockManager(spark.AddedToCache(rddId, partition, host)) - } - } - /** * Read a block consisting of a single object. 
*/ diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala deleted file mode 100644 index 467605981b..0000000000 --- a/core/src/test/scala/spark/CacheTrackerSuite.scala +++ /dev/null @@ -1,131 +0,0 @@ -package spark - -import org.scalatest.FunSuite - -import scala.collection.mutable.HashMap - -import akka.actor._ -import akka.dispatch._ -import akka.pattern.ask -import akka.remote._ -import akka.util.Duration -import akka.util.Timeout -import akka.util.duration._ - -class CacheTrackerSuite extends FunSuite { - // Send a message to an actor and wait for a reply, in a blocking manner - private def ask(actor: ActorRef, message: Any): Any = { - try { - val timeout = 10.seconds - val future = actor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with actor", e) - } - } - - test("CacheTrackerActor slave initialization & cache status") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 0L))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - test("RegisterRDD") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, RegisterRDD(1, 3)) === true) - assert(ask(tracker, RegisterRDD(2, 1)) === true) - - assert(getCacheLocations(tracker) === Map(1 -> List(Nil, Nil, Nil), 2 -> List(Nil))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - test("AddedToCache") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, RegisterRDD(1, 2)) === true) - assert(ask(tracker, RegisterRDD(2, 1)) === true) - - assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true) - assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true) - assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L))) - - assert(getCacheLocations(tracker) === - Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - test("DroppedFromCache") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, RegisterRDD(1, 2)) === true) - assert(ask(tracker, RegisterRDD(2, 1)) === true) - - assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true) - assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true) - assert(ask(tracker, 
AddedToCache(2, 0, "host001", 3L << 10)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L))) - assert(getCacheLocations(tracker) === - Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) - - assert(ask(tracker, DroppedFromCache(1, 1, "host001", 2L << 11)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 68608L))) - assert(getCacheLocations(tracker) === - Map(1 -> List(List("host001"),List()), 2 -> List(List("host001")))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - /** - * Helper function to get cacheLocations from CacheTracker - */ - def getCacheLocations(tracker: ActorRef): HashMap[Int, List[List[String]]] = { - val answer = ask(tracker, GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]] - answer.map { case (i, arr) => (i, arr.toList) } - } -} -- cgit v1.2.3 From 43e9ff959645e533bcfa0a5c31e62e32c7e9d0a6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 22 Jan 2013 22:47:26 -0800 Subject: Add test for driver hanging on exit (SPARK-530). --- core/src/test/scala/spark/DriverSuite.scala | 31 +++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 core/src/test/scala/spark/DriverSuite.scala (limited to 'core') diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala new file mode 100644 index 0000000000..70a7c8bc2f --- /dev/null +++ b/core/src/test/scala/spark/DriverSuite.scala @@ -0,0 +1,31 @@ +package spark + +import java.io.File + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts +import org.scalatest.prop.TableDrivenPropertyChecks._ +import org.scalatest.time.SpanSugar._ + +class DriverSuite extends FunSuite with Timeouts { + test("driver should exit after finishing") { + // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" + val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) + forAll(masters) { (master: String) => + failAfter(10 seconds) { + Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master), new File(System.getenv("SPARK_HOME"))) + } + } + } +} + +/** + * Program that creates a Spark driver but doesn't call SparkContext.stop() or + * Sys.exit() after finishing. + */ +object DriverWithoutCleanup { + def main(args: Array[String]) { + val sc = new SparkContext(args(0), "DriverWithoutCleanup") + sc.parallelize(1 to 100, 4).count() + } +} \ No newline at end of file -- cgit v1.2.3 From bacade6caf7527737dc6f02b1c2ca9114e02d8bc Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Jan 2013 22:55:26 -0800 Subject: Modified BlockManagerId API to ensure zero duplicate objects. Fixed BlockManagerId testcase in BlockManagerTestSuite. 
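Both apply() factories below, including the one that reads from an ObjectInput, hand the freshly built id to getCachedBlockManagerId, so every equal (ip, port) pair collapses to a single shared object and the new reference-equality assertions in BlockManagerSuite can hold. The helper's body is outside these hunks; a minimal interning sketch over the blockManagerIdCache it uses (an assumption, not the exact code) would be:

    // Assumed sketch of the interning helper the new factories delegate to; only the cache
    // declaration is visible in the hunks below, not the method body.
    import java.util.concurrent.ConcurrentHashMap

    val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()

    def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
      val previous = blockManagerIdCache.putIfAbsent(id, id)  // relies on value-based equals/hashCode
      if (previous == null) id else previous                  // always hand back the canonical instance
    }
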
--- .../src/main/scala/spark/scheduler/MapStatus.scala | 2 +- .../main/scala/spark/storage/BlockManager.scala | 2 +- .../main/scala/spark/storage/BlockManagerId.scala | 33 ++++++++++++++++++---- .../scala/spark/storage/BlockManagerMessages.scala | 3 +- .../test/scala/spark/MapOutputTrackerSuite.scala | 22 +++++++-------- .../scala/spark/storage/BlockManagerSuite.scala | 18 ++++++------ 6 files changed, 51 insertions(+), 29 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala index 4532d9497f..fae643f3a8 100644 --- a/core/src/main/scala/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/spark/scheduler/MapStatus.scala @@ -20,7 +20,7 @@ private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes: } def readExternal(in: ObjectInput) { - address = new BlockManagerId(in) + address = BlockManagerId(in) compressedSizes = new Array[Byte](in.readInt()) in.readFully(compressedSizes) } diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 7a8ac10cdd..596a69c583 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -69,7 +69,7 @@ class BlockManager( implicit val futureExecContext = connectionManager.futureExecContext val connectionManagerId = connectionManager.id - val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port) + val blockManagerId = BlockManagerId(connectionManagerId.host, connectionManagerId.port) // TODO: This will be removed after cacheTracker is removed from the code base. var cacheTracker: CacheTracker = null diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index 488679f049..26c98f2ac8 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -3,20 +3,35 @@ package spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.ConcurrentHashMap +/** + * This class represent an unique identifier for a BlockManager. + * The first 2 constructors of this class is made private to ensure that + * BlockManagerId objects can be created only using the factory method in + * [[spark.storage.BlockManager$]]. This allows de-duplication of id objects. + * Also, constructor parameters are private to ensure that parameters cannot + * be modified from outside this class. 
+ */ +private[spark] class BlockManagerId private ( + private var ip_ : String, + private var port_ : Int + ) extends Externalizable { + + private def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) -private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { def this() = this(null, 0) // For deserialization only - def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) + def ip = ip_ + + def port = port_ override def writeExternal(out: ObjectOutput) { - out.writeUTF(ip) - out.writeInt(port) + out.writeUTF(ip_) + out.writeInt(port_) } override def readExternal(in: ObjectInput) { - ip = in.readUTF() - port = in.readInt() + ip_ = in.readUTF() + port_ = in.readInt() } @throws(classOf[IOException]) @@ -35,6 +50,12 @@ private[spark] class BlockManagerId(var ip: String, var port: Int) extends Exter private[spark] object BlockManagerId { + def apply(ip: String, port: Int) = + getCachedBlockManagerId(new BlockManagerId(ip, port)) + + def apply(in: ObjectInput) = + getCachedBlockManagerId(new BlockManagerId(in)) + val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index d73a9b790f..7437fc63eb 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -54,8 +54,7 @@ class UpdateBlockInfo( } override def readExternal(in: ObjectInput) { - blockManagerId = new BlockManagerId() - blockManagerId.readExternal(in) + blockManagerId = BlockManagerId(in) blockId = in.readUTF() storageLevel = new StorageLevel() storageLevel.readExternal(in) diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index d3dd3a8fa4..095f415978 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -47,13 +47,13 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000), Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000), Array(compressedSize10000, compressedSize1000))) val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === Seq((new BlockManagerId("hostA", 1000), size1000), - (new BlockManagerId("hostB", 1000), size10000))) + assert(statuses.toSeq === Seq((BlockManagerId("hostA", 1000), size1000), + (BlockManagerId("hostB", 1000), size10000))) tracker.stop() } @@ -65,14 +65,14 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000), + tracker.registerMapOutput(10, 0, new 
MapStatus(BlockManagerId("hostA", 1000), Array(compressedSize1000, compressedSize1000, compressedSize1000))) - tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000), Array(compressedSize10000, compressedSize1000, compressedSize1000))) // As if we had two simulatenous fetch failures - tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) - tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) // The remaining reduce task might try to grab the output dispite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the @@ -95,13 +95,13 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { val compressedSize1000 = MapOutputTracker.compressSize(1000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) masterTracker.registerMapOutput(10, 0, new MapStatus( - new BlockManagerId("hostA", 1000), Array(compressedSize1000))) + BlockManagerId("hostA", 1000), Array(compressedSize1000))) masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((new BlockManagerId("hostA", 1000), size1000))) + Seq((BlockManagerId("hostA", 1000), size1000))) - masterTracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) + masterTracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 8f86e3170e..a33d3324ba 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -82,16 +82,18 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("BlockManagerId object caching") { - val id1 = new StorageLevel(false, false, false, 3) - val id2 = new StorageLevel(false, false, false, 3) + val id1 = BlockManagerId("XXX", 1) + val id2 = BlockManagerId("XXX", 1) // this should return the same object as id1 + assert(id2 === id1, "id2 is not same as id1") + assert(id2.eq(id1), "id2 is not the same object as id1") val bytes1 = spark.Utils.serialize(id1) - val id1_ = spark.Utils.deserialize[StorageLevel](bytes1) + val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1) val bytes2 = spark.Utils.serialize(id2) - val id2_ = spark.Utils.deserialize[StorageLevel](bytes2) - assert(id1_ === id1, "Deserialized id1 not same as original id1") - assert(id2_ === id2, "Deserialized id2 not same as original id1") - assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2") - assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized level1") + val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2) + assert(id1_ === id1, "Deserialized id1 is not same as original id1") + assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1") + assert(id2_ === id2, "Deserialized id2 is not same as original id2") + assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1") } test("master + 1 manager interaction") { -- 
cgit v1.2.3 From 5e11f1e51f17113abb8d3a5bc261af5ba5ffce94 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Jan 2013 23:42:53 -0800 Subject: Modified StorageLevel API to ensure zero duplicate objects. --- .../main/scala/spark/storage/BlockManager.scala | 5 +-- .../main/scala/spark/storage/BlockMessage.scala | 2 +- .../main/scala/spark/storage/StorageLevel.scala | 47 ++++++++++++++-------- .../scala/spark/storage/BlockManagerSuite.scala | 16 +++++--- 4 files changed, 44 insertions(+), 26 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 596a69c583..ca7eb13ec8 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -191,7 +191,7 @@ class BlockManager( case level => val inMem = level.useMemory && memoryStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) - val storageLevel = new StorageLevel(onDisk, inMem, level.deserialized, level.replication) + val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication) val memSize = if (inMem) memoryStore.getSize(blockId) else 0L val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L (storageLevel, memSize, diskSize, info.tellMaster) @@ -760,8 +760,7 @@ class BlockManager( */ var cachedPeers: Seq[BlockManagerId] = null private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { - val tLevel: StorageLevel = - new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) + val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) if (cachedPeers == null) { cachedPeers = master.getPeers(blockManagerId, level.replication - 1) } diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala index 3f234df654..30d7500e01 100644 --- a/core/src/main/scala/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/spark/storage/BlockMessage.scala @@ -64,7 +64,7 @@ private[spark] class BlockMessage() { val booleanInt = buffer.getInt() val replication = buffer.getInt() - level = new StorageLevel(booleanInt, replication) + level = StorageLevel(booleanInt, replication) val dataLength = buffer.getInt() data = ByteBuffer.allocate(dataLength) diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index e3544e5aae..f2535ae5ae 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -7,25 +7,30 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory * in a serialized format, and whether to replicate the RDD partitions on multiple nodes. * The [[spark.storage.StorageLevel$]] singleton object contains some static constants for - * commonly useful storage levels. + * commonly useful storage levels. The recommended method to create your own storage level + * object is to use `StorageLevel.apply(...)` from the singleton object. 
*/ class StorageLevel( - var useDisk: Boolean, - var useMemory: Boolean, - var deserialized: Boolean, - var replication: Int = 1) + private var useDisk_ : Boolean, + private var useMemory_ : Boolean, + private var deserialized_ : Boolean, + private var replication_ : Int = 1) extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. - - assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") - - def this(flags: Int, replication: Int) { + private def this(flags: Int, replication: Int) { this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } def this() = this(false, true, false) // For deserialization + def useDisk = useDisk_ + def useMemory = useMemory_ + def deserialized = deserialized_ + def replication = replication_ + + assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") + override def clone(): StorageLevel = new StorageLevel( this.useDisk, this.useMemory, this.deserialized, this.replication) @@ -43,13 +48,13 @@ class StorageLevel( def toInt: Int = { var ret = 0 - if (useDisk) { + if (useDisk_) { ret |= 4 } - if (useMemory) { + if (useMemory_) { ret |= 2 } - if (deserialized) { + if (deserialized_) { ret |= 1 } return ret @@ -57,15 +62,15 @@ class StorageLevel( override def writeExternal(out: ObjectOutput) { out.writeByte(toInt) - out.writeByte(replication) + out.writeByte(replication_) } override def readExternal(in: ObjectInput) { val flags = in.readByte() - useDisk = (flags & 4) != 0 - useMemory = (flags & 2) != 0 - deserialized = (flags & 1) != 0 - replication = in.readByte() + useDisk_ = (flags & 4) != 0 + useMemory_ = (flags & 2) != 0 + deserialized_ = (flags & 1) != 0 + replication_ = in.readByte() } @throws(classOf[IOException]) @@ -91,6 +96,14 @@ object StorageLevel { val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) + /** Create a new StorageLevel object */ + def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) = + getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication)) + + /** Create a new StorageLevel object from its integer representation */ + def apply(flags: Int, replication: Int) = + getCachedStorageLevel(new StorageLevel(flags, replication)) + private[spark] val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index a33d3324ba..a1aeb12f25 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -69,23 +69,29 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("StorageLevel object caching") { - val level1 = new StorageLevel(false, false, false, 3) - val level2 = new StorageLevel(false, false, false, 3) + val level1 = StorageLevel(false, false, false, 3) + val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1 + val level3 = StorageLevel(false, false, false, 2) // this should return a different object + assert(level2 === level1, "level2 is not same as level1") + assert(level2.eq(level1), "level2 is not the same object as level1") + assert(level3 != level1, "level3 is same as level1") val bytes1 = spark.Utils.serialize(level1) val level1_ = 
spark.Utils.deserialize[StorageLevel](bytes1) val bytes2 = spark.Utils.serialize(level2) val level2_ = spark.Utils.deserialize[StorageLevel](bytes2) assert(level1_ === level1, "Deserialized level1 not same as original level1") - assert(level2_ === level2, "Deserialized level2 not same as original level1") - assert(level1_ === level2_, "Deserialized level1 not same as deserialized level2") - assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1") + assert(level1_.eq(level1), "Deserialized level1 not the same object as original level2") + assert(level2_ === level2, "Deserialized level2 not same as original level2") + assert(level2_.eq(level1), "Deserialized level2 not the same object as original level1") } test("BlockManagerId object caching") { val id1 = BlockManagerId("XXX", 1) val id2 = BlockManagerId("XXX", 1) // this should return the same object as id1 + val id3 = BlockManagerId("XXX", 2) // this should return a different object assert(id2 === id1, "id2 is not same as id1") assert(id2.eq(id1), "id2 is not the same object as id1") + assert(id3 != id1, "id3 is same as id1") val bytes1 = spark.Utils.serialize(id1) val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1) val bytes2 = spark.Utils.serialize(id2) -- cgit v1.2.3 From 155f31398dc83ecb88b4b3e07849a2a8a0a6592f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 23 Jan 2013 01:10:26 -0800 Subject: Made StorageLevel constructor private, and added StorageLevels.create() to the Java API. Updates scala and java programming guides. --- core/src/main/scala/spark/api/java/StorageLevels.java | 11 +++++++++++ core/src/main/scala/spark/storage/StorageLevel.scala | 6 +++--- docs/java-programming-guide.md | 3 ++- docs/scala-programming-guide.md | 3 ++- 4 files changed, 18 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/java/StorageLevels.java b/core/src/main/scala/spark/api/java/StorageLevels.java index 722af3c06c..5e5845ac3a 100644 --- a/core/src/main/scala/spark/api/java/StorageLevels.java +++ b/core/src/main/scala/spark/api/java/StorageLevels.java @@ -17,4 +17,15 @@ public class StorageLevels { public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2); public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1); public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2); + + /** + * Create a new StorageLevel object. + * @param useDisk saved to disk, if true + * @param useMemory saved to memory, if true + * @param deserialized saved as deserialized objects, if true + * @param replication replication factor + */ + public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) { + return StorageLevel.apply(useDisk, useMemory, deserialized, replication); + } } diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index f2535ae5ae..45d6ea2656 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -7,10 +7,10 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory * in a serialized format, and whether to replicate the RDD partitions on multiple nodes. 
* The [[spark.storage.StorageLevel$]] singleton object contains some static constants for - * commonly useful storage levels. The recommended method to create your own storage level - * object is to use `StorageLevel.apply(...)` from the singleton object. + * commonly useful storage levels. To create your own storage level object, use the factor method + * of the singleton object (`StorageLevel(...)`). */ -class StorageLevel( +class StorageLevel private( private var useDisk_ : Boolean, private var useMemory_ : Boolean, private var deserialized_ : Boolean, diff --git a/docs/java-programming-guide.md b/docs/java-programming-guide.md index 188ca4995e..37a906ea1c 100644 --- a/docs/java-programming-guide.md +++ b/docs/java-programming-guide.md @@ -75,7 +75,8 @@ class has a single abstract method, `call()`, that must be implemented. ## Storage Levels RDD [storage level](scala-programming-guide.html#rdd-persistence) constants, such as `MEMORY_AND_DISK`, are -declared in the [spark.api.java.StorageLevels](api/core/index.html#spark.api.java.StorageLevels) class. +declared in the [spark.api.java.StorageLevels](api/core/index.html#spark.api.java.StorageLevels) class. To +define your own storage level, you can use StorageLevels.create(...). # Other Features diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md index 7350eca837..301b330a79 100644 --- a/docs/scala-programming-guide.md +++ b/docs/scala-programming-guide.md @@ -301,7 +301,8 @@ We recommend going through the following process to select one: * Use the replicated storage levels if you want fast fault recovery (e.g. if using Spark to serve requests from a web application). *All* the storage levels provide full fault tolerance by recomputing lost data, but the replicated ones let you continue running tasks on the RDD without waiting to recompute a lost partition. - + +If you want to define your own storage level (say, with replication factor of 3 instead of 2), then use the function factor method `apply()` of the [`StorageLevel`](api/core/index.html#spark.storage.StorageLevel$) singleton object. # Shared Variables -- cgit v1.2.3 From 9a27062260490336a3bfa97c6efd39b1e7e81573 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 01:34:44 -0800 Subject: Force generation increment after shuffle map stage --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 39a1e6d6c6..d8a9049e81 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -445,9 +445,16 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with logInfo("waiting: " + waiting) logInfo("failed: " + failed) if (stage.shuffleDep != None) { + // We supply true to increment the generation number here in case this is a + // recomputation of the map outputs. In that case, some nodes may have cached + // locations with holes (from when we detected the error) and will need the + // generation incremented to refetch them. + // TODO: Only increment the generation number if this is not the first time + // we registered these map outputs. 
mapOutputTracker.registerMapOutputs( stage.shuffleDep.get.shuffleId, - stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray) + stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + true) } updateCacheLocs() if (stage.outputLocs.count(_ == Nil) != 0) { -- cgit v1.2.3 From d209b6b7641059610f734414ea05e0494b5510b0 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 01:35:14 -0800 Subject: Extra debugging from hostLost() --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index d8a9049e81..740aec2e61 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -528,7 +528,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration) if (!failedGeneration.contains(host) || failedGeneration(host) < currentGeneration) { failedGeneration(host) = currentGeneration - logInfo("Host lost: " + host) + logInfo("Host lost: " + host + " (generation " + currentGeneration + ")") env.blockManager.master.notifyADeadHost(host) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { @@ -541,6 +541,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } cacheTracker.cacheLost(host) updateCacheLocs() + } else { + logDebug("Additional host lost message for " + host + + "(generation " + currentGeneration + ")") } } -- cgit v1.2.3 From 0b506dd2ecec909cd514143389d0846db2d194ed Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 01:37:51 -0800 Subject: Add tests of various node failure scenarios. 
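The new tests force real executor loss rather than simulating it: one pass over the data plants a flag in whichever executor JVM runs the closure, a second pass makes the flagged JVM exit, and the assertions only pass if the scheduler notices the dead executor and recomputes the lost partitions (the shuffle-map and shuffle-reduce variants repeat this three times). Condensed usage, taken from the tests that follow:

    // Condensed from the DistributedSuite tests below (runs inside the suite, against a
    // local-cluster master): pass 1 marks an executor JVM, pass 2 kills it with System.exit(42).
    val data = sc.parallelize(Seq(true, true), 2)
    assert(data.count === 2)                                  // force executors to start
    assert(data.map(markNodeIfIdentity).collect.size === 2)   // remember "this JVM should die"
    assert(data.map(failOnMarkedIdentity).collect.size === 2) // passes only if the lost tasks are re-run
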
--- core/src/test/scala/spark/DistributedSuite.scala | 72 ++++++++++++++++++++++++ 1 file changed, 72 insertions(+) (limited to 'core') diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index cacc2796b6..0d6b265e54 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -188,4 +188,76 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter val values = sc.parallelize(1 to 2, 2).map(x => System.getenv("TEST_VAR")).collect() assert(values.toSeq === Seq("TEST_VALUE", "TEST_VALUE")) } + + test("recover from node failures") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(Seq(true, true), 2) + val singleton = sc.parallelize(Seq(true), 1) + assert(data.count === 2) // force executors to start + val masterId = SparkEnv.get.blockManager.blockManagerId + assert(data.map(markNodeIfIdentity).collect.size === 2) + assert(data.map(failOnMarkedIdentity).collect.size === 2) + } + + test("recover from repeated node failures during shuffle-map") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + for (i <- 1 to 3) { + val data = sc.parallelize(Seq(true, false), 2) + val singleton = sc.parallelize(Seq(false), 1) + assert(data.count === 2) + assert(data.map(markNodeIfIdentity).collect.size === 2) + assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2) + } + } + + test("recover from repeated node failures during shuffle-reduce") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + for (i <- 1 to 3) { + val data = sc.parallelize(Seq(true, true), 2) + val singleton = sc.parallelize(Seq(false), 1) + assert(data.count === 2) + assert(data.map(markNodeIfIdentity).collect.size === 2) + // This relies on mergeCombiners being used to perform the actual reduce for this + // test to actually be testing what it claims. + val grouped = data.map(x => x -> x).combineByKey( + x => x, + (x: Boolean, y: Boolean) => x, + (x: Boolean, y: Boolean) => failOnMarkedIdentity(x) + ) + assert(grouped.collect.size === 1) + } + } +} + +object DistributedSuite { + // Indicates whether this JVM is marked for failure. + var mark = false + + // Set by test to remember if we are in the driver program so we can assert + // that we are not. + var amMaster = false + + // Act like an identity function, but if the argument is true, set mark to true. + def markNodeIfIdentity(item: Boolean): Boolean = { + if (item) { + assert(!amMaster) + mark = true + } + item + } + + // Act like an identity function, but if mark was set to true previously, fail, + // crashing the entire JVM. + def failOnMarkedIdentity(item: Boolean): Boolean = { + if (mark) { + System.exit(42) + } + item + } } -- cgit v1.2.3 From 79d55700ce2559051ac61cc2fb72a67fd7035926 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 23 Jan 2013 01:57:09 -0800 Subject: One more fix. Made even default constructor of BlockManagerId private to prevent such problems in the future. 
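With the no-argument constructor private as well, nothing outside the companion object can build a BlockManagerId that skips the cache; deserialization has to go through apply(in), and StorageLevel gets the same companion-side read path. A sketch of the resulting read path (it mirrors the hunks below, with the reasoning spelled out as comments):

    // Companion-object read path after this change (mirrors the BlockManagerId hunk below):
    def apply(in: ObjectInput): BlockManagerId = {
      val obj = new BlockManagerId()   // private no-arg constructor, reachable only from the companion
      obj.readExternal(in)             // populate ip_ and port_ from the stream
      getCachedBlockManagerId(obj)     // return the canonical instance from the cache (this copy, if it is the first)
    }

Callers such as UpdateBlockInfo.readExternal now read both fields the same way, as blockManagerId = BlockManagerId(in) and storageLevel = StorageLevel(in), instead of constructing blank objects and filling them in themselves.
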
--- core/src/main/scala/spark/storage/BlockManagerId.scala | 11 ++++++----- core/src/main/scala/spark/storage/BlockManagerMessages.scala | 3 +-- core/src/main/scala/spark/storage/StorageLevel.scala | 7 +++++++ 3 files changed, 14 insertions(+), 7 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index 26c98f2ac8..abb8b45a1f 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -16,9 +16,7 @@ private[spark] class BlockManagerId private ( private var port_ : Int ) extends Externalizable { - private def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) - - def this() = this(null, 0) // For deserialization only + private def this() = this(null, 0) // For deserialization only def ip = ip_ @@ -53,8 +51,11 @@ private[spark] object BlockManagerId { def apply(ip: String, port: Int) = getCachedBlockManagerId(new BlockManagerId(ip, port)) - def apply(in: ObjectInput) = - getCachedBlockManagerId(new BlockManagerId(in)) + def apply(in: ObjectInput) = { + val obj = new BlockManagerId() + obj.readExternal(in) + getCachedBlockManagerId(obj) + } val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index 7437fc63eb..30483b0b37 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -56,8 +56,7 @@ class UpdateBlockInfo( override def readExternal(in: ObjectInput) { blockManagerId = BlockManagerId(in) blockId = in.readUTF() - storageLevel = new StorageLevel() - storageLevel.readExternal(in) + storageLevel = StorageLevel(in) memSize = in.readInt() diskSize = in.readInt() } diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index 45d6ea2656..d1d1c61c1c 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -104,6 +104,13 @@ object StorageLevel { def apply(flags: Int, replication: Int) = getCachedStorageLevel(new StorageLevel(flags, replication)) + /** Read StorageLevel object from ObjectInput stream */ + def apply(in: ObjectInput) = { + val obj = new StorageLevel() + obj.readExternal(in) + getCachedStorageLevel(obj) + } + private[spark] val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() -- cgit v1.2.3 From ae2ed2947d43860c74a8d40767e289ca78073977 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 23 Jan 2013 10:36:18 -0800 Subject: Allow PySpark's SparkFiles to be used from driver Fix minor documentation formatting issues. 
--- core/src/main/scala/spark/SparkFiles.java | 8 ++++---- python/pyspark/context.py | 27 +++++++++++++++++++++------ python/pyspark/files.py | 20 +++++++++++++++++--- python/pyspark/tests.py | 23 +++++++++++++++++++++++ python/pyspark/worker.py | 1 + python/test_support/hello.txt | 1 + 6 files changed, 67 insertions(+), 13 deletions(-) create mode 100755 python/test_support/hello.txt (limited to 'core') diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java index b59d8ce93f..566aec622c 100644 --- a/core/src/main/scala/spark/SparkFiles.java +++ b/core/src/main/scala/spark/SparkFiles.java @@ -3,23 +3,23 @@ package spark; import java.io.File; /** - * Resolves paths to files added through `addFile(). + * Resolves paths to files added through `SparkContext.addFile()`. */ public class SparkFiles { private SparkFiles() {} /** - * Get the absolute path of a file added through `addFile()`. + * Get the absolute path of a file added through `SparkContext.addFile()`. */ public static String get(String filename) { return new File(getRootDirectory(), filename).getAbsolutePath(); } /** - * Get the root directory that contains files added through `addFile()`. + * Get the root directory that contains files added through `SparkContext.addFile()`. */ public static String getRootDirectory() { return SparkEnv.get().sparkFilesDir(); } -} \ No newline at end of file +} diff --git a/python/pyspark/context.py b/python/pyspark/context.py index b8d7dc05af..3e33776af0 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1,12 +1,15 @@ import os import atexit import shutil +import sys import tempfile +from threading import Lock from tempfile import NamedTemporaryFile from pyspark import accumulators from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast +from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import dump_pickle, write_with_length, batched from pyspark.rdd import RDD @@ -27,6 +30,8 @@ class SparkContext(object): _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile _takePartition = jvm.PythonRDD.takePartition _next_accum_id = 0 + _active_spark_context = None + _lock = Lock() def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): @@ -46,6 +51,11 @@ class SparkContext(object): Java object. Set 1 to disable batching or -1 to use an unlimited batch size. """ + with SparkContext._lock: + if SparkContext._active_spark_context: + raise ValueError("Cannot run multiple SparkContexts at once") + else: + SparkContext._active_spark_context = self self.master = master self.jobName = jobName self.sparkHome = sparkHome or None # None becomes null in Py4J @@ -75,6 +85,8 @@ class SparkContext(object): # Deploy any code dependencies specified in the constructor for path in (pyFiles or []): self.addPyFile(path) + SparkFiles._sc = self + sys.path.append(SparkFiles.getRootDirectory()) @property def defaultParallelism(self): @@ -85,17 +97,20 @@ class SparkContext(object): return self._jsc.sc().defaultParallelism() def __del__(self): - if self._jsc: - self._jsc.stop() - if self._accumulatorServer: - self._accumulatorServer.shutdown() + self.stop() def stop(self): """ Shut down the SparkContext. 
""" - self._jsc.stop() - self._jsc = None + if self._jsc: + self._jsc.stop() + self._jsc = None + if self._accumulatorServer: + self._accumulatorServer.shutdown() + self._accumulatorServer = None + with SparkContext._lock: + SparkContext._active_spark_context = None def parallelize(self, c, numSlices=None): """ diff --git a/python/pyspark/files.py b/python/pyspark/files.py index de1334f046..98f6a399cc 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/files.py @@ -4,13 +4,15 @@ import os class SparkFiles(object): """ Resolves paths to files added through - L{addFile()}. + L{SparkContext.addFile()}. SparkFiles contains only classmethods; users should not create SparkFiles instances. """ _root_directory = None + _is_running_on_worker = False + _sc = None def __init__(self): raise NotImplementedError("Do not construct SparkFiles objects") @@ -18,7 +20,19 @@ class SparkFiles(object): @classmethod def get(cls, filename): """ - Get the absolute path of a file added through C{addFile()}. + Get the absolute path of a file added through C{SparkContext.addFile()}. """ - path = os.path.join(SparkFiles._root_directory, filename) + path = os.path.join(SparkFiles.getRootDirectory(), filename) return os.path.abspath(path) + + @classmethod + def getRootDirectory(cls): + """ + Get the root directory that contains files added through + C{SparkContext.addFile()}. + """ + if cls._is_running_on_worker: + return cls._root_directory + else: + # This will have to change if we support multiple SparkContexts: + return cls._sc.jvm.spark.SparkFiles.getRootDirectory() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 4d70ee4f12..46ab34f063 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -4,22 +4,26 @@ individual modules. """ import os import shutil +import sys from tempfile import NamedTemporaryFile import time import unittest from pyspark.context import SparkContext +from pyspark.files import SparkFiles from pyspark.java_gateway import SPARK_HOME class PySparkTestCase(unittest.TestCase): def setUp(self): + self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ self.sc = SparkContext('local[4]', class_name , batchSize=2) def tearDown(self): self.sc.stop() + sys.path = self._old_sys_path # To avoid Akka rebinding to the same port, since it doesn't unbind # immediately on shutdown self.sc.jvm.System.clearProperty("spark.master.port") @@ -84,6 +88,25 @@ class TestAddFile(PySparkTestCase): res = self.sc.parallelize(range(2)).map(func).first() self.assertEqual("Hello World!", res) + def test_add_file_locally(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello.txt") + self.sc.addFile(path) + download_path = SparkFiles.get("hello.txt") + self.assertNotEqual(path, download_path) + with open(download_path) as test_file: + self.assertEquals("Hello World!\n", test_file.readline()) + + def test_add_py_file_locally(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this fails due to `userlibrary` not being on the Python path: + def func(): + from userlibrary import UserClass + self.assertRaises(ImportError, func) + path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") + self.sc.addFile(path) + from userlibrary import UserClass + self.assertEqual("Hello World!", UserClass().hello()) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4bf643da66..d33d6dd15f 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -26,6 +26,7 
@@ def main(): split_index = read_int(sys.stdin) spark_files_dir = load_pickle(read_with_length(sys.stdin)) SparkFiles._root_directory = spark_files_dir + SparkFiles._is_running_on_worker = True sys.path.append(spark_files_dir) num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): diff --git a/python/test_support/hello.txt b/python/test_support/hello.txt new file mode 100755 index 0000000000..980a0d5f19 --- /dev/null +++ b/python/test_support/hello.txt @@ -0,0 +1 @@ +Hello World! -- cgit v1.2.3 From e1027ca6398fd5b1a99a2203df840911c4dccb27 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:22:11 -0800 Subject: Actually add CacheManager. --- core/src/main/scala/spark/CacheManager.scala | 65 ++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 core/src/main/scala/spark/CacheManager.scala (limited to 'core') diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala new file mode 100644 index 0000000000..a0b53fd9d6 --- /dev/null +++ b/core/src/main/scala/spark/CacheManager.scala @@ -0,0 +1,65 @@ +package spark + +import scala.collection.mutable.{ArrayBuffer, HashSet} +import spark.storage.{BlockManager, StorageLevel} + + +/** Spark class responsible for passing RDDs split contents to the BlockManager and making + sure a node doesn't load two copies of an RDD at once. + */ +private[spark] class CacheManager(blockManager: BlockManager) extends Logging { + private val loading = new HashSet[String] + + /** Gets or computes an RDD split. Used by RDD.iterator() when a RDD is cached. */ + def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) + : Iterator[T] = { + val key = "rdd_%d_%d".format(rdd.id, split.index) + logInfo("Cache key is " + key) + blockManager.get(key) match { + case Some(cachedValues) => + // Split is in cache, so just return its values + logInfo("Found partition in cache!") + return cachedValues.asInstanceOf[Iterator[T]] + + case None => + // Mark the split as loading (unless someone else marks it first) + loading.synchronized { + if (loading.contains(key)) { + logInfo("Loading contains " + key + ", waiting...") + while (loading.contains(key)) { + try {loading.wait()} catch {case _ =>} + } + logInfo("Loading no longer contains " + key + ", so returning cached result") + // See whether someone else has successfully loaded it. The main way this would fail + // is for the RDD-level cache eviction policy if someone else has loaded the same RDD + // partition but we didn't want to make space for it. However, that case is unlikely + // because it's unlikely that two threads would work on the same RDD partition. One + // downside of the current code is that threads wait serially if this does happen. 
+ blockManager.get(key) match { + case Some(values) => + return values.asInstanceOf[Iterator[T]] + case None => + logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") + loading.add(key) + } + } else { + loading.add(key) + } + } + try { + // If we got here, we have to load the split + val elements = new ArrayBuffer[Any] + logInfo("Computing partition " + split) + elements ++= rdd.compute(split, context) + // Try to put this block in the blockManager + blockManager.put(key, elements, storageLevel, true) + return elements.iterator.asInstanceOf[Iterator[T]] + } finally { + loading.synchronized { + loading.remove(key) + loading.notifyAll() + } + } + } + } +} -- cgit v1.2.3 From 88b9d240fda7ca34c08752dfa66797eecb6db872 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:40:38 -0800 Subject: Remove dead code in test. --- core/src/test/scala/spark/DistributedSuite.scala | 2 -- 1 file changed, 2 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index 0d6b265e54..af66d33aa3 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -194,7 +194,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter DistributedSuite.amMaster = true sc = new SparkContext(clusterUrl, "test") val data = sc.parallelize(Seq(true, true), 2) - val singleton = sc.parallelize(Seq(true), 1) assert(data.count === 2) // force executors to start val masterId = SparkEnv.get.blockManager.blockManagerId assert(data.map(markNodeIfIdentity).collect.size === 2) @@ -207,7 +206,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter sc = new SparkContext(clusterUrl, "test") for (i <- 1 to 3) { val data = sc.parallelize(Seq(true, false), 2) - val singleton = sc.parallelize(Seq(false), 1) assert(data.count === 2) assert(data.map(markNodeIfIdentity).collect.size === 2) assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2) -- cgit v1.2.3 From be4a115a7ec7fb6ec0d34f1a1a1bb2c9bbe7600e Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:48:45 -0800 Subject: Clarify TODO. --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 740aec2e61..14a3ef8ad7 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -76,7 +76,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // sent with every task. When we detect a node failing, we note the current generation number // and failed host, increment it for new tasks, and use this to ignore stray ShuffleMapTask // results. - // TODO: Garbage collect information about failure generations when new stages start. + // TODO: Garbage collect information about failure generations when we know there are no more + // stray messages to detect. 
val failedGeneration = new HashMap[String, Long] val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done -- cgit v1.2.3 From e1985bfa04ad4583ac1f0f421cbe0182ce7c53df Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 21 Jan 2013 16:21:14 -0800 Subject: be sure to set class loader of kryo instances --- core/src/main/scala/spark/KryoSerializer.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 93d7327324..56919544e8 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -206,5 +206,8 @@ class KryoSerializer extends spark.serializer.Serializer with Logging { kryo } - def newInstance(): SerializerInstance = new KryoSerializerInstance(this) + def newInstance(): SerializerInstance = { + this.kryo.setClassLoader(Thread.currentThread().getContextClassLoader) + new KryoSerializerInstance(this) + } } -- cgit v1.2.3 From 5c7422292ecace947f78e5ebe97e83a355531af7 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:59:51 -0800 Subject: Remove more dead code from test. --- core/src/test/scala/spark/DistributedSuite.scala | 1 - 1 file changed, 1 deletion(-) (limited to 'core') diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index af66d33aa3..0487e06d12 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -218,7 +218,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter sc = new SparkContext(clusterUrl, "test") for (i <- 1 to 3) { val data = sc.parallelize(Seq(true, true), 2) - val singleton = sc.parallelize(Seq(false), 1) assert(data.count === 2) assert(data.map(markNodeIfIdentity).collect.size === 2) // This relies on mergeCombiners being used to perform the actual reduce for this -- cgit v1.2.3 From 1dd82743e09789f8fdae2f5628545c0cb9f79245 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Wed, 23 Jan 2013 13:07:27 -0800 Subject: Fix compile error due to cherry-pick --- core/src/main/scala/spark/KryoSerializer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 56919544e8..0bd73e936b 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -207,7 +207,7 @@ class KryoSerializer extends spark.serializer.Serializer with Logging { } def newInstance(): SerializerInstance = { - this.kryo.setClassLoader(Thread.currentThread().getContextClassLoader) + this.kryo.get().setClassLoader(Thread.currentThread().getContextClassLoader) new KryoSerializerInstance(this) } } -- cgit v1.2.3 From eb222b720647c9e92a867c591cc4914b9a6cb5c1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 15:29:02 -0800 Subject: Added pruneSplits method to RDD.
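Usage sketch, mirroring the new test in RDDSuite below (a local SparkContext is assumed; split indices start at 0, so keeping indices greater than 8 leaves only the tenth split):

    import spark.SparkContext

    val sc = new SparkContext("local", "test")
    val data = sc.parallelize(1 to 10, 10)                  // one element per split
    // Only the last split survives, so no tasks run over the other nine splits.
    val pruned = data.pruneSplits(splitIndex => splitIndex > 8)
    assert(pruned.splits.size == 1)
    assert(pruned.collect().toSeq == Seq(10))
    sc.stop()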
--- core/src/main/scala/spark/RDD.scala | 10 +++++++++ .../main/scala/spark/rdd/SplitsPruningRDD.scala | 24 ++++++++++++++++++++++ core/src/test/scala/spark/RDDSuite.scala | 22 +++++++++++++------- 3 files changed, 49 insertions(+), 7 deletions(-) create mode 100644 core/src/main/scala/spark/rdd/SplitsPruningRDD.scala (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e0d2eabb1d..3d93ff33bb 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -40,6 +40,7 @@ import spark.rdd.MapPartitionsRDD import spark.rdd.MapPartitionsWithSplitRDD import spark.rdd.PipedRDD import spark.rdd.SampledRDD +import spark.rdd.SplitsPruningRDD import spark.rdd.UnionRDD import spark.rdd.ZippedRDD import spark.storage.StorageLevel @@ -543,6 +544,15 @@ abstract class RDD[T: ClassManifest]( map(x => (f(x), x)) } + /** + * Prune splits (partitions) so Spark can avoid launching tasks on + * all splits. An example use case: If we know the RDD is partitioned by range, + * and the execution DAG has a filter on the key, we can avoid launching tasks + * on splits that don't have the range covering the key. + */ + def pruneSplits(splitsFilterFunc: Int => Boolean): RDD[T] = + new SplitsPruningRDD(this, splitsFilterFunc) + /** A private method for tests, to look at the contents of each partition */ private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) diff --git a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala new file mode 100644 index 0000000000..74e10265fc --- /dev/null +++ b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala @@ -0,0 +1,24 @@ +package spark.rdd + +import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext} + +/** + * A RDD used to prune RDD splits so we can avoid launching tasks on + * all splits. An example use case: If we know the RDD is partitioned by range, + * and the execution DAG has a filter on the key, we can avoid launching tasks + * on splits that don't have the range covering the key. 
+ */ +class SplitsPruningRDD[T: ClassManifest]( + prev: RDD[T], + @transient splitsFilterFunc: Int => Boolean) + extends RDD[T](prev) { + + @transient + val _splits: Array[Split] = prev.splits.filter(s => splitsFilterFunc(s.index)) + + override def compute(split: Split, context: TaskContext) = prev.iterator(split, context) + + override protected def getSplits = _splits + + override val partitioner = prev.partitioner +} diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index db217f8482..03aa2845f4 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -1,11 +1,9 @@ package spark import scala.collection.mutable.HashMap -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - +import org.scalatest.{BeforeAndAfter, FunSuite} +import spark.SparkContext._ import spark.rdd.CoalescedRDD -import SparkContext._ class RDDSuite extends FunSuite with BeforeAndAfter { @@ -104,7 +102,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { } test("caching with failures") { - sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val onlySplit = new Split { override def index: Int = 0 } var shouldFail = true val rdd = new RDD[Int](sc, Nil) { @@ -136,8 +134,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter { List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) // Check that the narrow dependency is also specified correctly - assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === List(0, 1, 2, 3, 4)) - assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === + List(0, 1, 2, 3, 4)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === + List(5, 6, 7, 8, 9)) val coalesced2 = new CoalescedRDD(data, 3) assert(coalesced2.collect().toList === (1 to 10).toList) @@ -168,4 +168,12 @@ class RDDSuite extends FunSuite with BeforeAndAfter { nums.zip(sc.parallelize(1 to 4, 1)).collect() } } + + test("split pruning") { + sc = new SparkContext("local", "test") + val data = sc.parallelize(1 to 10, 10) + // Note that split number starts from 0, so > 8 means only 10th partition left. + val prunedData = data.pruneSplits(splitNum => splitNum > 8).collect + assert(prunedData.size == 1 && prunedData(0) == 10) + } } -- cgit v1.2.3 From c24b3819dd474e13d6098150c174b2e7e4bc6498 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 15:34:59 -0800 Subject: Added an extra assert for split size check. --- core/src/test/scala/spark/RDDSuite.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 03aa2845f4..ef74c99246 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -173,7 +173,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter { sc = new SparkContext("local", "test") val data = sc.parallelize(1 to 10, 10) // Note that split number starts from 0, so > 8 means only 10th partition left. 
- val prunedData = data.pruneSplits(splitNum => splitNum > 8).collect - assert(prunedData.size == 1 && prunedData(0) == 10) + val prunedRdd = data.pruneSplits(splitNum => splitNum > 8) + assert(prunedRdd.splits.size == 1) + val prunedData = prunedRdd.collect + assert(prunedData.size == 1) + assert(prunedData(0) == 10) } } -- cgit v1.2.3 From 45cd50d5fe40869cdc237157e073cfb5ac47b27c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 16:06:58 -0800 Subject: Updated assert == to ===. --- core/src/test/scala/spark/RDDSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index ef74c99246..5a3a12dfff 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -174,9 +174,9 @@ class RDDSuite extends FunSuite with BeforeAndAfter { val data = sc.parallelize(1 to 10, 10) // Note that split number starts from 0, so > 8 means only 10th partition left. val prunedRdd = data.pruneSplits(splitNum => splitNum > 8) - assert(prunedRdd.splits.size == 1) + assert(prunedRdd.splits.size === 1) val prunedData = prunedRdd.collect - assert(prunedData.size == 1) - assert(prunedData(0) == 10) + assert(prunedData.size === 1) + assert(prunedData(0) === 10) } } -- cgit v1.2.3 From 636e912f3289e422be9550752f5279d519062b75 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 21:21:55 -0800 Subject: Created a PruneDependency to properly assign dependency for SplitsPruningRDD. --- core/src/main/scala/spark/Dependency.scala | 24 +++++++++++++++++++--- .../main/scala/spark/rdd/SplitsPruningRDD.scala | 8 ++++---- 2 files changed, 25 insertions(+), 7 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index b85d2732db..7d5858e88e 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -5,6 +5,7 @@ package spark */ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable + /** * Base class for dependencies where each partition of the parent RDD is used by at most one * partition of the child RDD. Narrow dependencies allow for pipelined execution. @@ -12,12 +13,13 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { /** * Get the parent partitions for a child partition. - * @param outputPartition a partition of the child RDD + * @param partitionId a partition of the child RDD * @return the partitions of the parent RDD that the child partition depends upon */ - def getParents(outputPartition: Int): Seq[Int] + def getParents(partitionId: Int): Seq[Int] } + /** * Represents a dependency on the output of a shuffle stage. * @param shuffleId the shuffle id @@ -32,6 +34,7 @@ class ShuffleDependency[K, V]( val shuffleId: Int = rdd.context.newShuffleId() } + /** * Represents a one-to-one dependency between partitions of the parent and child RDDs. */ @@ -39,6 +42,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { override def getParents(partitionId: Int) = List(partitionId) } + /** * Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs. 
* @param rdd the parent RDD @@ -48,7 +52,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { */ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) extends NarrowDependency[T](rdd) { - + override def getParents(partitionId: Int) = { if (partitionId >= outStart && partitionId < outStart + length) { List(partitionId - outStart + inStart) @@ -57,3 +61,17 @@ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) } } } + + +/** + * Represents a dependency between the SplitsPruningRDD and its parent. In this + * case, the child RDD contains a subset of splits of the parents'. + */ +class PruneDependency[T](rdd: RDD[T], @transient splitsFilterFunc: Int => Boolean) + extends NarrowDependency[T](rdd) { + + @transient + val splits: Array[Split] = rdd.splits.filter(s => splitsFilterFunc(s.index)) + + override def getParents(partitionId: Int) = List(splits(partitionId).index) +} diff --git a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala index 74e10265fc..7b44d85bb5 100644 --- a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala @@ -1,6 +1,6 @@ package spark.rdd -import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext} +import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} /** * A RDD used to prune RDD splits so we can avoid launching tasks on @@ -11,12 +11,12 @@ import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext} class SplitsPruningRDD[T: ClassManifest]( prev: RDD[T], @transient splitsFilterFunc: Int => Boolean) - extends RDD[T](prev) { + extends RDD[T](prev.context, List(new PruneDependency(prev, splitsFilterFunc))) { @transient - val _splits: Array[Split] = prev.splits.filter(s => splitsFilterFunc(s.index)) + val _splits: Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].splits - override def compute(split: Split, context: TaskContext) = prev.iterator(split, context) + override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) override protected def getSplits = _splits -- cgit v1.2.3 From 81004b967e838fca0790727a3fea5a265ddbc69a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 21:54:27 -0800 Subject: Marked prev RDD as transient in SplitsPruningRDD. --- core/src/main/scala/spark/rdd/SplitsPruningRDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala index 7b44d85bb5..9b1a210ba3 100644 --- a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala @@ -9,7 +9,7 @@ import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} * on splits that don't have the range covering the key. 
*/ class SplitsPruningRDD[T: ClassManifest]( - prev: RDD[T], + @transient prev: RDD[T], @transient splitsFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, splitsFilterFunc))) { @@ -20,5 +20,5 @@ class SplitsPruningRDD[T: ClassManifest]( override protected def getSplits = _splits - override val partitioner = prev.partitioner + override val partitioner = firstParent[T].partitioner } -- cgit v1.2.3 From eedc542a0276a5248c81446ee84f56d691e5f488 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 22:14:23 -0800 Subject: Removed pruneSplits method in RDD and renamed SplitsPruningRDD to PartitionPruningRDD. --- core/src/main/scala/spark/RDD.scala | 10 --------- .../main/scala/spark/rdd/PartitionPruningRDD.scala | 24 ++++++++++++++++++++++ .../main/scala/spark/rdd/SplitsPruningRDD.scala | 24 ---------------------- core/src/test/scala/spark/RDDSuite.scala | 6 +++--- 4 files changed, 27 insertions(+), 37 deletions(-) create mode 100644 core/src/main/scala/spark/rdd/PartitionPruningRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/SplitsPruningRDD.scala (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 3d93ff33bb..e0d2eabb1d 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -40,7 +40,6 @@ import spark.rdd.MapPartitionsRDD import spark.rdd.MapPartitionsWithSplitRDD import spark.rdd.PipedRDD import spark.rdd.SampledRDD -import spark.rdd.SplitsPruningRDD import spark.rdd.UnionRDD import spark.rdd.ZippedRDD import spark.storage.StorageLevel @@ -544,15 +543,6 @@ abstract class RDD[T: ClassManifest]( map(x => (f(x), x)) } - /** - * Prune splits (partitions) so Spark can avoid launching tasks on - * all splits. An example use case: If we know the RDD is partitioned by range, - * and the execution DAG has a filter on the key, we can avoid launching tasks - * on splits that don't have the range covering the key. - */ - def pruneSplits(splitsFilterFunc: Int => Boolean): RDD[T] = - new SplitsPruningRDD(this, splitsFilterFunc) - /** A private method for tests, to look at the contents of each partition */ private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala new file mode 100644 index 0000000000..3048949ef2 --- /dev/null +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -0,0 +1,24 @@ +package spark.rdd + +import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} + +/** + * A RDD used to prune RDD partitions/splits so we can avoid launching tasks on + * all partitions. An example use case: If we know the RDD is partitioned by range, + * and the execution DAG has a filter on the key, we can avoid launching tasks + * on partitions that don't have the range covering the key. 
+ */ +class PartitionPruningRDD[T: ClassManifest]( + @transient prev: RDD[T], + @transient partitionFilterFunc: Int => Boolean) + extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { + + @transient + val partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].splits + + override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) + + override protected def getSplits = partitions_ + + override val partitioner = firstParent[T].partitioner +} diff --git a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala deleted file mode 100644 index 9b1a210ba3..0000000000 --- a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala +++ /dev/null @@ -1,24 +0,0 @@ -package spark.rdd - -import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} - -/** - * A RDD used to prune RDD splits so we can avoid launching tasks on - * all splits. An example use case: If we know the RDD is partitioned by range, - * and the execution DAG has a filter on the key, we can avoid launching tasks - * on splits that don't have the range covering the key. - */ -class SplitsPruningRDD[T: ClassManifest]( - @transient prev: RDD[T], - @transient splitsFilterFunc: Int => Boolean) - extends RDD[T](prev.context, List(new PruneDependency(prev, splitsFilterFunc))) { - - @transient - val _splits: Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].splits - - override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) - - override protected def getSplits = _splits - - override val partitioner = firstParent[T].partitioner -} diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 5a3a12dfff..73846131a9 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -3,7 +3,7 @@ package spark import scala.collection.mutable.HashMap import org.scalatest.{BeforeAndAfter, FunSuite} import spark.SparkContext._ -import spark.rdd.CoalescedRDD +import spark.rdd.{CoalescedRDD, PartitionPruningRDD} class RDDSuite extends FunSuite with BeforeAndAfter { @@ -169,11 +169,11 @@ class RDDSuite extends FunSuite with BeforeAndAfter { } } - test("split pruning") { + test("partition pruning") { sc = new SparkContext("local", "test") val data = sc.parallelize(1 to 10, 10) // Note that split number starts from 0, so > 8 means only 10th partition left. - val prunedRdd = data.pruneSplits(splitNum => splitNum > 8) + val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) assert(prunedRdd.splits.size === 1) val prunedData = prunedRdd.collect assert(prunedData.size === 1) -- cgit v1.2.3 From c109f29c97c9606dee45e6300d01a272dbb560aa Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 22:22:03 -0800 Subject: Updated PruneDependency to change "split" to "partition". 
--- core/src/main/scala/spark/Dependency.scala | 10 +++++----- core/src/main/scala/spark/rdd/PartitionPruningRDD.scala | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index 7d5858e88e..647aee6eb5 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -64,14 +64,14 @@ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) /** - * Represents a dependency between the SplitsPruningRDD and its parent. In this - * case, the child RDD contains a subset of splits of the parents'. + * Represents a dependency between the PartitionPruningRDD and its parent. In this + * case, the child RDD contains a subset of partitions of the parents'. */ -class PruneDependency[T](rdd: RDD[T], @transient splitsFilterFunc: Int => Boolean) +class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) extends NarrowDependency[T](rdd) { @transient - val splits: Array[Split] = rdd.splits.filter(s => splitsFilterFunc(s.index)) + val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index)) - override def getParents(partitionId: Int) = List(splits(partitionId).index) + override def getParents(partitionId: Int) = List(partitions(partitionId).index) } diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 3048949ef2..787b59ae8c 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -14,7 +14,7 @@ class PartitionPruningRDD[T: ClassManifest]( extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { @transient - val partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].splits + val partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].partitions override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) -- cgit v1.2.3 From 67a43bc7e622e4dd9d53ccf80b441740d6ff4df5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 23:06:52 -0800 Subject: Added a clearDependencies method in PartitionPruningRDD. 
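With RDD.pruneSplits removed two commits back, callers now construct the pruning RDD directly; a sketch matching the updated RDDSuite test (a local SparkContext is assumed):

    import spark.SparkContext
    import spark.rdd.PartitionPruningRDD

    val sc = new SparkContext("local", "test")
    val data = sc.parallelize(1 to 10, 10)
    // Partition indices start at 0, so keeping indices greater than 8 leaves only the tenth one.
    val pruned = new PartitionPruningRDD(data, partitionIndex => partitionIndex > 8)
    assert(pruned.splits.size == 1)
    assert(pruned.collect().toSeq == Seq(10))
    sc.stop()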
--- core/src/main/scala/spark/rdd/PartitionPruningRDD.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 787b59ae8c..97dd37950e 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -14,11 +14,16 @@ class PartitionPruningRDD[T: ClassManifest]( extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { @transient - val partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].partitions + var partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].partitions override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) override protected def getSplits = partitions_ override val partitioner = firstParent[T].partitioner + + override def clearDependencies() { + super.clearDependencies() + partitions_ = null + } } -- cgit v1.2.3 From 230bda204778e6f3c0f5a20ad341f643146d97cb Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 15 Jan 2013 14:01:19 -0600 Subject: Add LocalSparkContext to manage common sc variable. --- core/src/test/scala/spark/AccumulatorSuite.scala | 32 ++-------- core/src/test/scala/spark/BroadcastSuite.scala | 14 +---- core/src/test/scala/spark/CheckpointSuite.scala | 19 ++---- .../src/test/scala/spark/ClosureCleanerSuite.scala | 73 ++++++++++------------ core/src/test/scala/spark/DistributedSuite.scala | 23 ++----- core/src/test/scala/spark/FailureSuite.scala | 14 +---- core/src/test/scala/spark/FileServerSuite.scala | 16 ++--- core/src/test/scala/spark/FileSuite.scala | 16 +---- core/src/test/scala/spark/LocalSparkContext.scala | 41 ++++++++++++ .../test/scala/spark/MapOutputTrackerSuite.scala | 7 +-- core/src/test/scala/spark/PartitioningSuite.scala | 15 +---- core/src/test/scala/spark/PipedRDDSuite.scala | 16 +---- core/src/test/scala/spark/RDDSuite.scala | 14 +---- core/src/test/scala/spark/ShuffleSuite.scala | 14 +---- core/src/test/scala/spark/SortingSuite.scala | 13 +--- core/src/test/scala/spark/ThreadingSuite.scala | 14 +---- .../scala/spark/scheduler/TaskContextSuite.scala | 14 +---- 17 files changed, 109 insertions(+), 246 deletions(-) create mode 100644 core/src/test/scala/spark/LocalSparkContext.scala (limited to 'core') diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala index d8be99dde7..78d64a44ae 100644 --- a/core/src/test/scala/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/spark/AccumulatorSuite.scala @@ -1,6 +1,5 @@ package spark -import org.scalatest.BeforeAndAfter import org.scalatest.FunSuite import org.scalatest.matchers.ShouldMatchers import collection.mutable @@ -9,18 +8,7 @@ import scala.math.exp import scala.math.signum import spark.SparkContext._ -class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { - - var sc: SparkContext = null - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext { test ("basic accumulation"){ sc = new SparkContext("local", "test") @@ -53,10 +41,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter for (i <- 1 to 
maxI) { v should contain(i) } - sc.stop() - sc = null - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + resetSparkContext() } } @@ -86,10 +71,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter x => acc.value += x } } should produce [SparkException] - sc.stop() - sc = null - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + resetSparkContext() } } @@ -115,10 +97,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter bufferAcc.value should contain(i) mapAcc.value should contain (i -> i.toString) } - sc.stop() - sc = null - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + resetSparkContext() } } @@ -134,8 +113,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter x => acc.localValue ++= x } acc.value should be ( (0 to maxI).toSet) - sc.stop() - sc = null + resetSparkContext() } } diff --git a/core/src/test/scala/spark/BroadcastSuite.scala b/core/src/test/scala/spark/BroadcastSuite.scala index 2d3302f0aa..362a31fb0d 100644 --- a/core/src/test/scala/spark/BroadcastSuite.scala +++ b/core/src/test/scala/spark/BroadcastSuite.scala @@ -1,20 +1,8 @@ package spark import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter -class BroadcastSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class BroadcastSuite extends FunSuite with LocalSparkContext { test("basic broadcast") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 51573254ca..33c317720c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -1,34 +1,27 @@ package spark -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.FunSuite import java.io.File import spark.rdd._ import spark.SparkContext._ import storage.StorageLevel -class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { +class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { initLogging() - var sc: SparkContext = _ var checkpointDir: File = _ val partitioner = new HashPartitioner(2) - before { + override def beforeEach() { + super.beforeEach() checkpointDir = File.createTempFile("temp", "") checkpointDir.delete() - sc = new SparkContext("local", "test") sc.setCheckpointDir(checkpointDir.toString) } - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - + override def afterEach() { + super.afterEach() if (checkpointDir != null) { checkpointDir.delete() } diff --git a/core/src/test/scala/spark/ClosureCleanerSuite.scala b/core/src/test/scala/spark/ClosureCleanerSuite.scala index dfa2de80e6..b2d0dd4627 100644 --- a/core/src/test/scala/spark/ClosureCleanerSuite.scala +++ b/core/src/test/scala/spark/ClosureCleanerSuite.scala @@ -3,6 +3,7 @@ package spark import java.io.NotSerializableException import org.scalatest.FunSuite +import 
spark.LocalSparkContext._ import SparkContext._ class ClosureCleanerSuite extends FunSuite { @@ -43,13 +44,10 @@ object TestObject { def run(): Int = { var nonSer = new NonSerializable var x = 5 - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - val answer = nums.map(_ + x).reduce(_ + _) - sc.stop() - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - return answer + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + x).reduce(_ + _) + } } } @@ -60,11 +58,10 @@ class TestClass extends Serializable { def run(): Int = { var nonSer = new NonSerializable - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - val answer = nums.map(_ + getX).reduce(_ + _) - sc.stop() - return answer + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + getX).reduce(_ + _) + } } } @@ -73,11 +70,10 @@ class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { def run(): Int = { var nonSer = new NonSerializable - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - val answer = nums.map(_ + getX).reduce(_ + _) - sc.stop() - return answer + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + getX).reduce(_ + _) + } } } @@ -89,11 +85,10 @@ class TestClassWithoutFieldAccess { def run(): Int = { var nonSer2 = new NonSerializable var x = 5 - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - val answer = nums.map(_ + x).reduce(_ + _) - sc.stop() - return answer + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + x).reduce(_ + _) + } } } @@ -102,16 +97,16 @@ object TestObjectWithNesting { def run(): Int = { var nonSer = new NonSerializable var answer = 0 - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - var y = 1 - for (i <- 1 to 4) { - var nonSer2 = new NonSerializable - var x = i - answer += nums.map(_ + x + y).reduce(_ + _) + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + var y = 1 + for (i <- 1 to 4) { + var nonSer2 = new NonSerializable + var x = i + answer += nums.map(_ + x + y).reduce(_ + _) + } + answer } - sc.stop() - return answer } } @@ -121,14 +116,14 @@ class TestClassWithNesting(val y: Int) extends Serializable { def run(): Int = { var nonSer = new NonSerializable var answer = 0 - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - for (i <- 1 to 4) { - var nonSer2 = new NonSerializable - var x = i - answer += nums.map(_ + x + getY).reduce(_ + _) + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + for (i <- 1 to 4) { + var nonSer2 = new NonSerializable + var x = i + answer += nums.map(_ + x + getY).reduce(_ + _) + } + answer } - sc.stop() - return answer } } diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index cacc2796b6..83a2a549a9 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -15,41 +15,28 @@ import 
scala.collection.mutable.ArrayBuffer import SparkContext._ import storage.StorageLevel -class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { +class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext { val clusterUrl = "local-cluster[2,1,512]" - @transient var sc: SparkContext = _ - after { - if (sc != null) { - sc.stop() - sc = null - } System.clearProperty("spark.reducer.maxMbInFlight") System.clearProperty("spark.storage.memoryFraction") - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") } test("local-cluster format") { sc = new SparkContext("local-cluster[2,1,512]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) - sc.stop() - System.clearProperty("spark.master.port") + resetSparkContext() sc = new SparkContext("local-cluster[2 , 1 , 512]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) - sc.stop() - System.clearProperty("spark.master.port") + resetSparkContext() sc = new SparkContext("local-cluster[2, 1, 512]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) - sc.stop() - System.clearProperty("spark.master.port") + resetSparkContext() sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) - sc.stop() - System.clearProperty("spark.master.port") - sc = null + resetSparkContext() } test("simple groupByKey") { diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala index a3454f25f6..8c1445a465 100644 --- a/core/src/test/scala/spark/FailureSuite.scala +++ b/core/src/test/scala/spark/FailureSuite.scala @@ -1,7 +1,6 @@ package spark import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import org.scalatest.prop.Checkers import scala.collection.mutable.ArrayBuffer @@ -23,18 +22,7 @@ object FailureSuiteState { } } -class FailureSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class FailureSuite extends FunSuite with LocalSparkContext { // Run a 3-task map job in which task 1 deterministically fails once, and check // whether the job completes successfully and we ran 4 tasks in total. 
diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala index b4283d9604..8215cbde02 100644 --- a/core/src/test/scala/spark/FileServerSuite.scala +++ b/core/src/test/scala/spark/FileServerSuite.scala @@ -2,17 +2,16 @@ package spark import com.google.common.io.Files import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import java.io.{File, PrintWriter, FileReader, BufferedReader} import SparkContext._ -class FileServerSuite extends FunSuite with BeforeAndAfter { +class FileServerSuite extends FunSuite with LocalSparkContext { - @transient var sc: SparkContext = _ @transient var tmpFile : File = _ @transient var testJarFile : File = _ - before { + override def beforeEach() { + super.beforeEach() // Create a sample text file val tmpdir = new File(Files.createTempDir(), "test") tmpdir.mkdir() @@ -22,17 +21,12 @@ class FileServerSuite extends FunSuite with BeforeAndAfter { pw.close() } - after { - if (sc != null) { - sc.stop() - sc = null - } + override def afterEach() { + super.afterEach() // Clean up downloaded file if (tmpFile.exists) { tmpFile.delete() } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") } test("Distributing files locally") { diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala index 554bea53a9..91b48c7456 100644 --- a/core/src/test/scala/spark/FileSuite.scala +++ b/core/src/test/scala/spark/FileSuite.scala @@ -6,24 +6,12 @@ import scala.io.Source import com.google.common.io.Files import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import org.apache.hadoop.io._ import SparkContext._ -class FileSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } - +class FileSuite extends FunSuite with LocalSparkContext { + test("text files") { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala new file mode 100644 index 0000000000..b5e31ddae3 --- /dev/null +++ b/core/src/test/scala/spark/LocalSparkContext.scala @@ -0,0 +1,41 @@ +package spark + +import org.scalatest.Suite +import org.scalatest.BeforeAndAfterEach + +/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */ +trait LocalSparkContext extends BeforeAndAfterEach { self: Suite => + + @transient var sc: SparkContext = _ + + override def afterEach() { + resetSparkContext() + super.afterEach() + } + + def resetSparkContext() = { + if (sc != null) { + LocalSparkContext.stop(sc) + sc = null + } + } + +} + +object LocalSparkContext { + def stop(sc: SparkContext) { + sc.stop() + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") + } + + /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. 
*/ + def withSpark[T](sc: SparkContext)(f: SparkContext => T) = { + try { + f(sc) + } finally { + stop(sc) + } + } + +} \ No newline at end of file diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index d3dd3a8fa4..774bbd65b1 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -1,17 +1,13 @@ package spark import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import akka.actor._ import spark.scheduler.MapStatus import spark.storage.BlockManagerId import spark.util.AkkaUtils -class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { - after { - System.clearProperty("spark.master.port") - } +class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("compressSize") { assert(MapOutputTracker.compressSize(0L) === 0) @@ -81,7 +77,6 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { } test("remote fetch") { - System.clearProperty("spark.master.host") val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) System.setProperty("spark.master.port", boundPort.toString) diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index eb3c8f238f..af1107cd19 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -1,25 +1,12 @@ package spark import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import scala.collection.mutable.ArrayBuffer import SparkContext._ -class PartitioningSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if(sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } - +class PartitioningSuite extends FunSuite with LocalSparkContext { test("HashPartitioner equality") { val p2 = new HashPartitioner(2) diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala index 9b84b29227..a6344edf8f 100644 --- a/core/src/test/scala/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/spark/PipedRDDSuite.scala @@ -1,21 +1,9 @@ package spark import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import SparkContext._ -class PipedRDDSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class PipedRDDSuite extends FunSuite with LocalSparkContext { test("basic pipe") { sc = new SparkContext("local", "test") @@ -51,5 +39,3 @@ class PipedRDDSuite extends FunSuite with BeforeAndAfter { } } - - diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index db217f8482..592427e97a 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -2,23 +2,11 @@ package spark import scala.collection.mutable.HashMap import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import spark.rdd.CoalescedRDD import SparkContext._ -class RDDSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it 
doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class RDDSuite extends FunSuite with LocalSparkContext { test("basic operations") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index bebb8ebe86..3493b9511f 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -3,7 +3,6 @@ package spark import scala.collection.mutable.ArrayBuffer import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import org.scalatest.matchers.ShouldMatchers import org.scalatest.prop.Checkers import org.scalacheck.Arbitrary._ @@ -15,18 +14,7 @@ import com.google.common.io.Files import spark.rdd.ShuffledRDD import spark.SparkContext._ -class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { test("groupByKey") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala index 1ad11ff4c3..edb8c839fc 100644 --- a/core/src/test/scala/spark/SortingSuite.scala +++ b/core/src/test/scala/spark/SortingSuite.scala @@ -5,18 +5,7 @@ import org.scalatest.BeforeAndAfter import org.scalatest.matchers.ShouldMatchers import SparkContext._ -class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with Logging { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers with Logging { test("sortByKey") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/spark/ThreadingSuite.scala b/core/src/test/scala/spark/ThreadingSuite.scala index e9b1837d89..ff315b6693 100644 --- a/core/src/test/scala/spark/ThreadingSuite.scala +++ b/core/src/test/scala/spark/ThreadingSuite.scala @@ -22,19 +22,7 @@ object ThreadingSuiteState { } } -class ThreadingSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if(sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } - +class ThreadingSuite extends FunSuite with LocalSparkContext { test("accessing SparkContext form a different thread") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala index ba6f8b588f..a5db7103f5 100644 --- a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala @@ -6,19 +6,9 @@ import spark.TaskContext import spark.RDD import spark.SparkContext import spark.Split +import spark.LocalSparkContext -class TaskContextSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - 
System.clearProperty("spark.master.port") - } +class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { test("Calls executeOnCompleteCallbacks after failure") { var completed = false -- cgit v1.2.3 From b6fc6e67521e8a9a5291693cce3dc766da244395 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 24 Jan 2013 14:28:05 -0800 Subject: SPARK-541: Adding a warning for invalid Master URL Right now Spark silently parses master URL's which do not match any known regex as a Mesos URL. The Mesos error message when an invalid URL gets passed is really confusing, so this warns the user when the implicit conversion is happening. --- core/src/main/scala/spark/SparkContext.scala | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 66bdbe7cda..bc9fdee8b6 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -112,6 +112,8 @@ class SparkContext( val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters val SPARK_REGEX = """(spark://.*)""".r + //Regular expression for connection to Mesos cluster + val MESOS_REGEX = """(mesos://.*)""".r master match { case "local" => @@ -152,6 +154,9 @@ class SparkContext( scheduler case _ => + if (MESOS_REGEX.findFirstIn(master).isEmpty) { + logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) + } MesosNativeLibrary.load() val scheduler = new ClusterScheduler(this) val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean -- cgit v1.2.3 From 7dfb82a992d47491174d7929e31351d26cadfcda Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 15:25:41 -0600 Subject: Replace old 'master' term with 'driver'. 
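In user-facing terms, the rename below replaces the spark.master.* system properties with spark.driver.* (same defaults); a minimal before/after sketch, with the property names taken from the diff that follows and the val names purely illustrative:

    // Before this change: the coordinating process was addressed as the "master".
    val oldHost = System.getProperty("spark.master.host", "localhost")
    val oldPort = System.getProperty("spark.master.port", "7077").toInt

    // After this change: the same endpoint is addressed as the "driver".
    val host = System.getProperty("spark.driver.host", "localhost")
    val port = System.getProperty("spark.driver.port", "7077").toInt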
--- bagel/src/test/scala/bagel/BagelSuite.scala | 2 +- core/src/main/scala/spark/MapOutputTracker.scala | 10 +-- core/src/main/scala/spark/SparkContext.scala | 20 +++--- core/src/main/scala/spark/SparkEnv.scala | 22 +++---- .../spark/broadcast/BitTorrentBroadcast.scala | 24 +++---- .../src/main/scala/spark/broadcast/Broadcast.scala | 6 +- .../scala/spark/broadcast/BroadcastFactory.scala | 4 +- .../main/scala/spark/broadcast/HttpBroadcast.scala | 6 +- .../main/scala/spark/broadcast/MultiTracker.scala | 35 +++++----- .../main/scala/spark/broadcast/TreeBroadcast.scala | 52 +++++++-------- .../scala/spark/deploy/LocalSparkCluster.scala | 34 +++++----- .../scala/spark/deploy/client/ClientListener.scala | 4 +- .../main/scala/spark/deploy/master/JobInfo.scala | 2 +- .../main/scala/spark/deploy/master/Master.scala | 18 +++--- .../spark/executor/StandaloneExecutorBackend.scala | 26 ++++---- .../cluster/SparkDeploySchedulerBackend.scala | 33 +++++----- .../cluster/StandaloneClusterMessage.scala | 8 +-- .../cluster/StandaloneSchedulerBackend.scala | 74 +++++++++++----------- .../mesos/CoarseMesosSchedulerBackend.scala | 6 +- .../scala/spark/storage/BlockManagerMaster.scala | 69 ++++++++++---------- .../main/scala/spark/storage/ThreadingTest.scala | 6 +- core/src/test/scala/spark/JavaAPISuite.java | 2 +- core/src/test/scala/spark/LocalSparkContext.scala | 2 +- .../test/scala/spark/MapOutputTrackerSuite.scala | 2 +- docs/configuration.md | 12 ++-- python/pyspark/tests.py | 2 +- repl/src/test/scala/spark/repl/ReplSuite.scala | 2 +- .../streaming/dstream/NetworkInputDStream.scala | 4 +- .../test/java/spark/streaming/JavaAPISuite.java | 2 +- .../spark/streaming/BasicOperationsSuite.scala | 2 +- .../scala/spark/streaming/CheckpointSuite.scala | 2 +- .../test/scala/spark/streaming/FailureSuite.scala | 2 +- .../scala/spark/streaming/InputStreamsSuite.scala | 2 +- .../spark/streaming/WindowOperationsSuite.scala | 2 +- 34 files changed, 248 insertions(+), 251 deletions(-) (limited to 'core') diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala index ca59f46843..3c2f9c4616 100644 --- a/bagel/src/test/scala/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/bagel/BagelSuite.scala @@ -23,7 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter { sc = null } // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } test("halting by voting") { diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index ac02f3363a..d4f5164f7d 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -38,10 +38,7 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac } } -private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging { - val ip: String = System.getProperty("spark.master.host", "localhost") - val port: Int = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "MapOutputTracker" +private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolean) extends Logging { val timeout = 10.seconds @@ -56,11 +53,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea var cacheGeneration = generation val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] - var 
trackerActor: ActorRef = if (isMaster) { + val actorName: String = "MapOutputTracker" + var trackerActor: ActorRef = if (isDriver) { val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName) logInfo("Registered MapOutputTrackerActor actor") actor } else { + val ip = System.getProperty("spark.driver.host", "localhost") + val port = System.getProperty("spark.driver.port", "7077").toInt val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) actorSystem.actorFor(url) } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bc9fdee8b6..d4991cb1e0 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -66,20 +66,20 @@ class SparkContext( // Ensure logging is initialized before we spawn any threads initLogging() - // Set Spark master host and port system properties - if (System.getProperty("spark.master.host") == null) { - System.setProperty("spark.master.host", Utils.localIpAddress) + // Set Spark driver host and port system properties + if (System.getProperty("spark.driver.host") == null) { + System.setProperty("spark.driver.host", Utils.localIpAddress) } - if (System.getProperty("spark.master.port") == null) { - System.setProperty("spark.master.port", "0") + if (System.getProperty("spark.driver.port") == null) { + System.setProperty("spark.driver.port", "0") } private val isLocal = (master == "local" || master.startsWith("local[")) // Create the Spark execution environment (cache, map output tracker, etc) private[spark] val env = SparkEnv.createFromSystemProperties( - System.getProperty("spark.master.host"), - System.getProperty("spark.master.port").toInt, + System.getProperty("spark.driver.host"), + System.getProperty("spark.driver.port").toInt, true, isLocal) SparkEnv.set(env) @@ -396,14 +396,14 @@ class SparkContext( /** * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `+=` method. Only the driver can access the accumulator's `value`. */ def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = new Accumulator(initialValue, param) /** * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`. - * Only the master can access the accumuable's `value`. + * Only the driver can access the accumuable's `value`. * @tparam T accumulator type * @tparam R type that can be added to the accumulator */ @@ -530,7 +530,7 @@ class SparkContext( /** * Run a function on a given set of partitions in an RDD and return the results. This is the main * entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies - * whether the scheduler can run the computation on the master rather than shipping it out to the + * whether the scheduler can run the computation on the driver rather than shipping it out to the * cluster, for short actions like first(). 
*/ def runJob[T, U: ClassManifest]( diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 2a7a8af83d..4034af610c 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -60,15 +60,15 @@ object SparkEnv extends Logging { def createFromSystemProperties( hostname: String, port: Int, - isMaster: Boolean, + isDriver: Boolean, isLocal: Boolean ) : SparkEnv = { val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port) - // Bit of a hack: If this is the master and our port was 0 (meaning bind to any free port), - // figure out which port number Akka actually bound to and set spark.master.port to it. - if (isMaster && port == 0) { - System.setProperty("spark.master.port", boundPort.toString) + // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port), + // figure out which port number Akka actually bound to and set spark.driver.port to it. + if (isDriver && port == 0) { + System.setProperty("spark.driver.port", boundPort.toString) } val classLoader = Thread.currentThread.getContextClassLoader @@ -82,22 +82,22 @@ object SparkEnv extends Logging { val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer") - val masterIp: String = System.getProperty("spark.master.host", "localhost") - val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt + val driverIp: String = System.getProperty("spark.driver.host", "localhost") + val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt val blockManagerMaster = new BlockManagerMaster( - actorSystem, isMaster, isLocal, masterIp, masterPort) + actorSystem, isDriver, isLocal, driverIp, driverPort) val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer) val connectionManager = blockManager.connectionManager - val broadcastManager = new BroadcastManager(isMaster) + val broadcastManager = new BroadcastManager(isDriver) val closureSerializer = instantiateClass[Serializer]( "spark.closure.serializer", "spark.JavaSerializer") val cacheManager = new CacheManager(blockManager) - val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster) + val mapOutputTracker = new MapOutputTracker(actorSystem, isDriver) val shuffleFetcher = instantiateClass[ShuffleFetcher]( "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") @@ -109,7 +109,7 @@ object SparkEnv extends Logging { // Set the sparkFiles directory, used when downloading dependencies. In local mode, // this is a temporary directory; in distributed mode, this is the executor's current working // directory. - val sparkFilesDir: String = if (isMaster) { + val sparkFilesDir: String = if (isDriver) { Utils.createTempDir().getAbsolutePath } else { "." 
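The MapOutputTracker and SparkEnv changes above follow one wiring pattern: the driver creates an actor locally under a well-known name, while executors resolve it over Akka using the spark.driver.host and spark.driver.port properties. A minimal sketch of that pattern, assuming Akka 2.x; DriverActorWiring and trackerRef are illustrative names, not actual Spark APIs:

    import akka.actor.{Actor, ActorRef, ActorSystem, Props}

    // Illustrative helper mirroring the driver/executor split used by MapOutputTracker above.
    object DriverActorWiring {
      def trackerRef(actorSystem: ActorSystem, isDriver: Boolean)(makeActor: => Actor): ActorRef =
        if (isDriver) {
          // Driver side: host the actor locally under a well-known name.
          actorSystem.actorOf(Props(makeActor), name = "MapOutputTracker")
        } else {
          // Executor side: look up the driver's actor from the properties set by SparkContext.
          val ip = System.getProperty("spark.driver.host", "localhost")
          val port = System.getProperty("spark.driver.port", "7077").toInt
          actorSystem.actorFor("akka://spark@%s:%s/user/%s".format(ip, port, "MapOutputTracker"))
        }
    }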
diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala index 386f505f2a..adcb2d2415 100644 --- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala @@ -31,7 +31,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: @transient var totalBlocks = -1 @transient var hasBlocks = new AtomicInteger(0) - // Used ONLY by Master to track how many unique blocks have been sent out + // Used ONLY by driver to track how many unique blocks have been sent out @transient var sentBlocks = new AtomicInteger(0) @transient var listenPortLock = new Object @@ -42,7 +42,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: @transient var serveMR: ServeMultipleRequests = null - // Used only in Master + // Used only in driver @transient var guideMR: GuideMultipleRequests = null // Used only in Workers @@ -99,14 +99,14 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: } // Must always come AFTER listenPort is created - val masterSource = + val driverSource = SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) hasBlocksBitVector.synchronized { - masterSource.hasBlocksBitVector = hasBlocksBitVector + driverSource.hasBlocksBitVector = hasBlocksBitVector } // In the beginning, this is the only known source to Guide - listOfSources += masterSource + listOfSources += driverSource // Register with the Tracker MultiTracker.registerBroadcast(id, @@ -122,7 +122,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: case None => logInfo("Started reading broadcast variable " + id) - // Initializing everything because Master will only send null/0 values + // Initializing everything because driver will only send null/0 values // Only the 1st worker in a node can be here. Others will get from cache initializeWorkerVariables() @@ -151,7 +151,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: } } - // Initialize variables in the worker node. Master sends everything as 0/null + // Initialize variables in the worker node. Driver sends everything as 0/null private def initializeWorkerVariables() { arrayOfBlocks = null hasBlocksBitVector = null @@ -248,7 +248,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: // Receive source information from Guide var suitableSources = oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] - logDebug("Received suitableSources from Master " + suitableSources) + logDebug("Received suitableSources from Driver " + suitableSources) addToListOfSources(suitableSources) @@ -532,7 +532,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: oosSource.writeObject(blockToAskFor) oosSource.flush() - // CHANGED: Master might send some other block than the one + // CHANGED: Driver might send some other block than the one // requested to ensure fast spreading of all blocks. 
val recvStartTime = System.currentTimeMillis val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] @@ -982,9 +982,9 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: // Receive which block to send var blockToSend = ois.readObject.asInstanceOf[Int] - // If it is master AND at least one copy of each block has not been + // If it is driver AND at least one copy of each block has not been // sent out already, MODIFY blockToSend - if (MultiTracker.isMaster && sentBlocks.get < totalBlocks) { + if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) { blockToSend = sentBlocks.getAndIncrement } @@ -1031,7 +1031,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: private[spark] class BitTorrentBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) } + def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new BitTorrentBroadcast[T](value_, isLocal, id) diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala index 2ffe7f741d..415bde5d67 100644 --- a/core/src/main/scala/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/spark/broadcast/Broadcast.scala @@ -15,7 +15,7 @@ abstract class Broadcast[T](private[spark] val id: Long) extends Serializable { } private[spark] -class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable { +class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable { private var initialized = false private var broadcastFactory: BroadcastFactory = null @@ -33,7 +33,7 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] // Initialize appropriate BroadcastFactory and BroadcastObject - broadcastFactory.initialize(isMaster) + broadcastFactory.initialize(isDriver) initialized = true } @@ -49,5 +49,5 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl def newBroadcast[T](value_ : T, isLocal: Boolean) = broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) - def isMaster = isMaster_ + def isDriver = _isDriver } diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala index ab6d302827..5c6184c3c7 100644 --- a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala @@ -7,7 +7,7 @@ package spark.broadcast * entire Spark job. 
*/ private[spark] trait BroadcastFactory { - def initialize(isMaster: Boolean): Unit - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] + def initialize(isDriver: Boolean): Unit + def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] def stop(): Unit } diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 8e490e6bad..7e30b8f7d2 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -48,7 +48,7 @@ extends Broadcast[T](id) with Logging with Serializable { } private[spark] class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) { HttpBroadcast.initialize(isMaster) } + def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new HttpBroadcast[T](value_, isLocal, id) @@ -69,12 +69,12 @@ private object HttpBroadcast extends Logging { private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup) - def initialize(isMaster: Boolean) { + def initialize(isDriver: Boolean) { synchronized { if (!initialized) { bufferSize = System.getProperty("spark.buffer.size", "65536").toInt compress = System.getProperty("spark.broadcast.compress", "true").toBoolean - if (isMaster) { + if (isDriver) { createServer() } serverUri = System.getProperty("spark.httpBroadcast.uri") diff --git a/core/src/main/scala/spark/broadcast/MultiTracker.scala b/core/src/main/scala/spark/broadcast/MultiTracker.scala index 5e76dedb94..3fd77af73f 100644 --- a/core/src/main/scala/spark/broadcast/MultiTracker.scala +++ b/core/src/main/scala/spark/broadcast/MultiTracker.scala @@ -23,25 +23,24 @@ extends Logging { var ranGen = new Random private var initialized = false - private var isMaster_ = false + private var _isDriver = false private var stopBroadcast = false private var trackMV: TrackMultipleValues = null - def initialize(isMaster__ : Boolean) { + def initialize(__isDriver: Boolean) { synchronized { if (!initialized) { + _isDriver = __isDriver - isMaster_ = isMaster__ - - if (isMaster) { + if (isDriver) { trackMV = new TrackMultipleValues trackMV.setDaemon(true) trackMV.start() - // Set masterHostAddress to the master's IP address for the slaves to read - System.setProperty("spark.MultiTracker.MasterHostAddress", Utils.localIpAddress) + // Set DriverHostAddress to the driver's IP address for the slaves to read + System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress) } initialized = true @@ -54,10 +53,10 @@ extends Logging { } // Load common parameters - private var MasterHostAddress_ = System.getProperty( - "spark.MultiTracker.MasterHostAddress", "") - private var MasterTrackerPort_ = System.getProperty( - "spark.broadcast.masterTrackerPort", "11111").toInt + private var DriverHostAddress_ = System.getProperty( + "spark.MultiTracker.DriverHostAddress", "") + private var DriverTrackerPort_ = System.getProperty( + "spark.broadcast.driverTrackerPort", "11111").toInt private var BlockSize_ = System.getProperty( "spark.broadcast.blockSize", "4096").toInt * 1024 private var MaxRetryCount_ = System.getProperty( @@ -91,11 +90,11 @@ extends Logging { private var EndGameFraction_ = System.getProperty( "spark.broadcast.endGameFraction", "0.95").toDouble - def isMaster = isMaster_ + def isDriver = _isDriver // Common config params - def MasterHostAddress = MasterHostAddress_ - def MasterTrackerPort = 
MasterTrackerPort_ + def DriverHostAddress = DriverHostAddress_ + def DriverTrackerPort = DriverTrackerPort_ def BlockSize = BlockSize_ def MaxRetryCount = MaxRetryCount_ @@ -123,7 +122,7 @@ extends Logging { var threadPool = Utils.newDaemonCachedThreadPool() var serverSocket: ServerSocket = null - serverSocket = new ServerSocket(MasterTrackerPort) + serverSocket = new ServerSocket(DriverTrackerPort) logInfo("TrackMultipleValues started at " + serverSocket) try { @@ -235,7 +234,7 @@ extends Logging { try { // Connect to the tracker to find out GuideInfo clientSocketToTracker = - new Socket(MultiTracker.MasterHostAddress, MultiTracker.MasterTrackerPort) + new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort) oosTracker = new ObjectOutputStream(clientSocketToTracker.getOutputStream) oosTracker.flush() @@ -276,7 +275,7 @@ extends Logging { } def registerBroadcast(id: Long, gInfo: SourceInfo) { - val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort) + val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) val oosST = new ObjectOutputStream(socket.getOutputStream) oosST.flush() val oisST = new ObjectInputStream(socket.getInputStream) @@ -303,7 +302,7 @@ extends Logging { } def unregisterBroadcast(id: Long) { - val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort) + val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) val oosST = new ObjectOutputStream(socket.getOutputStream) oosST.flush() val oisST = new ObjectInputStream(socket.getInputStream) diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala index f573512835..c55c476117 100644 --- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala @@ -98,7 +98,7 @@ extends Broadcast[T](id) with Logging with Serializable { case None => logInfo("Started reading broadcast variable " + id) - // Initializing everything because Master will only send null/0 values + // Initializing everything because Driver will only send null/0 values // Only the 1st worker in a node can be here. 
Others will get from cache initializeWorkerVariables() @@ -157,55 +157,55 @@ extends Broadcast[T](id) with Logging with Serializable { listenPortLock.synchronized { listenPortLock.wait() } } - var clientSocketToMaster: Socket = null - var oosMaster: ObjectOutputStream = null - var oisMaster: ObjectInputStream = null + var clientSocketToDriver: Socket = null + var oosDriver: ObjectOutputStream = null + var oisDriver: ObjectInputStream = null // Connect and receive broadcast from the specified source, retrying the // specified number of times in case of failures var retriesLeft = MultiTracker.MaxRetryCount do { - // Connect to Master and send this worker's Information - clientSocketToMaster = new Socket(MultiTracker.MasterHostAddress, gInfo.listenPort) - oosMaster = new ObjectOutputStream(clientSocketToMaster.getOutputStream) - oosMaster.flush() - oisMaster = new ObjectInputStream(clientSocketToMaster.getInputStream) + // Connect to Driver and send this worker's Information + clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort) + oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream) + oosDriver.flush() + oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream) - logDebug("Connected to Master's guiding object") + logDebug("Connected to Driver's guiding object") // Send local source information - oosMaster.writeObject(SourceInfo(hostAddress, listenPort)) - oosMaster.flush() + oosDriver.writeObject(SourceInfo(hostAddress, listenPort)) + oosDriver.flush() - // Receive source information from Master - var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] + // Receive source information from Driver + var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo] totalBlocks = sourceInfo.totalBlocks arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } totalBytes = sourceInfo.totalBytes - logDebug("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) + logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort) val start = System.nanoTime val receptionSucceeded = receiveSingleTransmission(sourceInfo) val time = (System.nanoTime - start) / 1e9 - // Updating some statistics in sourceInfo. Master will be using them later + // Updating some statistics in sourceInfo. 
Driver will be using them later if (!receptionSucceeded) { sourceInfo.receptionFailed = true } - // Send back statistics to the Master - oosMaster.writeObject(sourceInfo) + // Send back statistics to the Driver + oosDriver.writeObject(sourceInfo) - if (oisMaster != null) { - oisMaster.close() + if (oisDriver != null) { + oisDriver.close() } - if (oosMaster != null) { - oosMaster.close() + if (oosDriver != null) { + oosDriver.close() } - if (clientSocketToMaster != null) { - clientSocketToMaster.close() + if (clientSocketToDriver != null) { + clientSocketToDriver.close() } retriesLeft -= 1 @@ -552,7 +552,7 @@ extends Broadcast[T](id) with Logging with Serializable { } private def sendObject() { - // Wait till receiving the SourceInfo from Master + // Wait till receiving the SourceInfo from Driver while (totalBlocks == -1) { totalBlocksLock.synchronized { totalBlocksLock.wait() } } @@ -576,7 +576,7 @@ extends Broadcast[T](id) with Logging with Serializable { private[spark] class TreeBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) } + def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new TreeBroadcast[T](value_, isLocal, id) diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala index 4211d80596..ae083efc8d 100644 --- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala @@ -10,7 +10,7 @@ import spark.{Logging, Utils} import scala.collection.mutable.ArrayBuffer private[spark] -class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging { +class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging { val localIpAddress = Utils.localIpAddress @@ -19,33 +19,31 @@ class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) var masterPort : Int = _ var masterUrl : String = _ - val slaveActorSystems = ArrayBuffer[ActorSystem]() - val slaveActors = ArrayBuffer[ActorRef]() + val workerActorSystems = ArrayBuffer[ActorSystem]() + val workerActors = ArrayBuffer[ActorRef]() def start() : String = { - logInfo("Starting a local Spark cluster with " + numSlaves + " slaves.") + logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") /* Start the Master */ val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0) masterActorSystem = actorSystem masterUrl = "spark://" + localIpAddress + ":" + masterPort - val actor = masterActorSystem.actorOf( + masterActor = masterActorSystem.actorOf( Props(new Master(localIpAddress, masterPort, 0)), name = "Master") - masterActor = actor - /* Start the Slaves */ - for (slaveNum <- 1 to numSlaves) { - /* We can pretend to test distributed stuff by giving the slaves distinct hostnames. + /* Start the Workers */ + for (workerNum <- 1 to numWorkers) { + /* We can pretend to test distributed stuff by giving the workers distinct hostnames. All of 127/8 should be a loopback, we use 127.100.*.* in hopes that it is sufficiently distinctive. */ - val slaveIpAddress = "127.100.0." + (slaveNum % 256) + val workerIpAddress = "127.100.0." 
+ (workerNum % 256) val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("sparkWorker" + slaveNum, slaveIpAddress, 0) - slaveActorSystems += actorSystem - val actor = actorSystem.actorOf( - Props(new Worker(slaveIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)), + AkkaUtils.createActorSystem("sparkWorker" + workerNum, workerIpAddress, 0) + workerActorSystems += actorSystem + workerActors += actorSystem.actorOf( + Props(new Worker(workerIpAddress, boundPort, 0, coresPerWorker, memoryPerWorker, masterUrl)), name = "Worker") - slaveActors += actor } return masterUrl @@ -53,9 +51,9 @@ class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) def stop() { logInfo("Shutting down local Spark cluster.") - // Stop the slaves before the master so they don't get upset that it disconnected - slaveActorSystems.foreach(_.shutdown()) - slaveActorSystems.foreach(_.awaitTermination()) + // Stop the workers before the master so they don't get upset that it disconnected + workerActorSystems.foreach(_.shutdown()) + workerActorSystems.foreach(_.awaitTermination()) masterActorSystem.shutdown() masterActorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala index da6abcc9c2..7035f4b394 100644 --- a/core/src/main/scala/spark/deploy/client/ClientListener.scala +++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala @@ -12,7 +12,7 @@ private[spark] trait ClientListener { def disconnected(): Unit - def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int): Unit + def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit - def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit + def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit } diff --git a/core/src/main/scala/spark/deploy/master/JobInfo.scala b/core/src/main/scala/spark/deploy/master/JobInfo.scala index 130b031a2a..a274b21c34 100644 --- a/core/src/main/scala/spark/deploy/master/JobInfo.scala +++ b/core/src/main/scala/spark/deploy/master/JobInfo.scala @@ -10,7 +10,7 @@ private[spark] class JobInfo( val id: String, val desc: JobDescription, val submitDate: Date, - val actor: ActorRef) + val driver: ActorRef) { var state = JobState.WAITING var executors = new mutable.HashMap[Int, ExecutorInfo] diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 2c2cd0231b..3347207c6d 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -88,7 +88,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor execOption match { case Some(exec) => { exec.state = state - exec.job.actor ! ExecutorUpdated(execId, state, message, exitStatus) + exec.job.driver ! ExecutorUpdated(execId, state, message, exitStatus) if (ExecutorState.isFinished(state)) { val jobInfo = idToJob(jobId) // Remove this executor from the worker and job @@ -199,7 +199,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) - exec.job.actor ! 
ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) + exec.job.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, @@ -221,19 +221,19 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor actorToWorker -= worker.actor addressToWorker -= worker.actor.path.address for (exec <- worker.executors.values) { - exec.job.actor ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None) + exec.job.driver ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None) exec.job.executors -= exec.id } } - def addJob(desc: JobDescription, actor: ActorRef): JobInfo = { + def addJob(desc: JobDescription, driver: ActorRef): JobInfo = { val now = System.currentTimeMillis() val date = new Date(now) - val job = new JobInfo(now, newJobId(date), desc, date, actor) + val job = new JobInfo(now, newJobId(date), desc, date, driver) jobs += job idToJob(job.id) = job - actorToJob(sender) = job - addressToJob(sender.path.address) = job + actorToJob(driver) = job + addressToJob(driver.path.address) = job return job } @@ -242,8 +242,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor logInfo("Removing job " + job.id) jobs -= job idToJob -= job.id - actorToJob -= job.actor - addressToWorker -= job.actor.path.address + actorToJob -= job.driver + addressToWorker -= job.driver.path.address completedJobs += job // Remember it in our history waitingJobs -= job for (exec <- job.executors.values) { diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index a29bf974d2..f80f1b5274 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -16,33 +16,33 @@ import spark.scheduler.cluster.RegisterSlave private[spark] class StandaloneExecutorBackend( executor: Executor, - masterUrl: String, - slaveId: String, + driverUrl: String, + workerId: String, hostname: String, cores: Int) extends Actor with ExecutorBackend with Logging { - var master: ActorRef = null + var driver: ActorRef = null override def preStart() { try { - logInfo("Connecting to master: " + masterUrl) - master = context.actorFor(masterUrl) - master ! RegisterSlave(slaveId, hostname, cores) + logInfo("Connecting to driver: " + driverUrl) + driver = context.actorFor(driverUrl) + driver ! RegisterSlave(workerId, hostname, cores) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(master) // Doesn't work with remote actors, but useful for testing + context.watch(driver) // Doesn't work with remote actors, but useful for testing } catch { case e: Exception => - logError("Failed to connect to master", e) + logError("Failed to connect to driver", e) System.exit(1) } } override def receive = { case RegisteredSlave(sparkProperties) => - logInfo("Successfully registered with master") + logInfo("Successfully registered with driver") executor.initialize(hostname, sparkProperties) case RegisterSlaveFailed(message) => @@ -55,24 +55,24 @@ private[spark] class StandaloneExecutorBackend( } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - master ! StatusUpdate(slaveId, taskId, state, data) + driver ! 
StatusUpdate(workerId, taskId, state, data) } } private[spark] object StandaloneExecutorBackend { - def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) { + def run(driverUrl: String, workerId: String, hostname: String, cores: Int) { // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) val actor = actorSystem.actorOf( - Props(new StandaloneExecutorBackend(new Executor, masterUrl, slaveId, hostname, cores)), + Props(new StandaloneExecutorBackend(new Executor, driverUrl, workerId, hostname, cores)), name = "Executor") actorSystem.awaitTermination() } def main(args: Array[String]) { if (args.length != 4) { - System.err.println("Usage: StandaloneExecutorBackend ") + System.err.println("Usage: StandaloneExecutorBackend ") System.exit(1) } run(args(0), args(1), args(2), args(3).toInt) diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 4f82cd96dd..866beb6d01 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,7 +19,7 @@ private[spark] class SparkDeploySchedulerBackend( var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt - val executorIdToSlaveId = new HashMap[String, String] + val executorIdToWorkerId = new HashMap[String, String] // Memory used by each executor (in megabytes) val executorMemory = { @@ -34,10 +34,11 @@ private[spark] class SparkDeploySchedulerBackend( override def start() { super.start() - val masterUrl = "akka://spark@%s:%s/user/%s".format( - System.getProperty("spark.master.host"), System.getProperty("spark.master.port"), + // The endpoint for executors to talk to us + val driverUrl = "akka://spark@%s:%s/user/%s".format( + System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), StandaloneSchedulerBackend.ACTOR_NAME) - val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") + val args = Seq(driverUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome) @@ -55,35 +56,35 @@ private[spark] class SparkDeploySchedulerBackend( } } - def connected(jobId: String) { + override def connected(jobId: String) { logInfo("Connected to Spark cluster with job ID " + jobId) } - def disconnected() { + override def disconnected() { if (!stopping) { logError("Disconnected from Spark cluster!") scheduler.error("Disconnected from Spark cluster") } } - def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) { - executorIdToSlaveId += id -> workerId + override def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int) { + executorIdToWorkerId += fullId -> workerId logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format( - id, host, cores, Utils.memoryMegabytesToString(memory))) + fullId, host, cores, Utils.memoryMegabytesToString(memory))) } 
- def executorRemoved(id: String, message: String, exitStatus: Option[Int]) { + override def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) { val reason: ExecutorLossReason = exitStatus match { case Some(code) => ExecutorExited(code) case None => SlaveLost(message) } - logInfo("Executor %s removed: %s".format(id, message)) - executorIdToSlaveId.get(id) match { - case Some(slaveId) => - executorIdToSlaveId.remove(id) - scheduler.slaveLost(slaveId, reason) + logInfo("Executor %s removed: %s".format(fullId, message)) + executorIdToWorkerId.get(fullId) match { + case Some(workerId) => + executorIdToWorkerId.remove(fullId) + scheduler.slaveLost(workerId, reason) case None => - logInfo("No slave ID known for executor %s".format(id)) + logInfo("No worker ID known for executor %s".format(fullId)) } } } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala index 1386cd9d44..bea9dc4f23 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -6,7 +6,7 @@ import spark.util.SerializableBuffer private[spark] sealed trait StandaloneClusterMessage extends Serializable -// Master to slaves +// Driver to executors private[spark] case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage @@ -16,7 +16,7 @@ case class RegisteredSlave(sparkProperties: Seq[(String, String)]) extends Stand private[spark] case class RegisterSlaveFailed(message: String) extends StandaloneClusterMessage -// Slaves to master +// Executors to driver private[spark] case class RegisterSlave(slaveId: String, host: String, cores: Int) extends StandaloneClusterMessage @@ -32,6 +32,6 @@ object StatusUpdate { } } -// Internal messages in master +// Internal messages in driver private[spark] case object ReviveOffers extends StandaloneClusterMessage -private[spark] case object StopMaster extends StandaloneClusterMessage +private[spark] case object StopDriver extends StandaloneClusterMessage diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index eeaae23dc8..d742a7b2bf 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -23,7 +23,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor // Use an atomic variable to track total number of cores in the cluster for simplicity and speed var totalCoreCount = new AtomicInteger(0) - class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor { + class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { val slaveActor = new HashMap[String, ActorRef] val slaveAddress = new HashMap[String, Address] val slaveHost = new HashMap[String, String] @@ -37,34 +37,34 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } def receive = { - case RegisterSlave(slaveId, host, cores) => - if (slaveActor.contains(slaveId)) { - sender ! RegisterSlaveFailed("Duplicate slave ID: " + slaveId) + case RegisterSlave(workerId, host, cores) => + if (slaveActor.contains(workerId)) { + sender ! 
RegisterSlaveFailed("Duplicate slave ID: " + workerId) } else { - logInfo("Registered slave: " + sender + " with ID " + slaveId) + logInfo("Registered slave: " + sender + " with ID " + workerId) sender ! RegisteredSlave(sparkProperties) context.watch(sender) - slaveActor(slaveId) = sender - slaveHost(slaveId) = host - freeCores(slaveId) = cores - slaveAddress(slaveId) = sender.path.address - actorToSlaveId(sender) = slaveId - addressToSlaveId(sender.path.address) = slaveId + slaveActor(workerId) = sender + slaveHost(workerId) = host + freeCores(workerId) = cores + slaveAddress(workerId) = sender.path.address + actorToSlaveId(sender) = workerId + addressToSlaveId(sender.path.address) = workerId totalCoreCount.addAndGet(cores) makeOffers() } - case StatusUpdate(slaveId, taskId, state, data) => + case StatusUpdate(workerId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { - freeCores(slaveId) += 1 - makeOffers(slaveId) + freeCores(workerId) += 1 + makeOffers(workerId) } case ReviveOffers => makeOffers() - case StopMaster => + case StopDriver => sender ! true context.stop(self) @@ -85,9 +85,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } // Make fake resource offers on just one slave - def makeOffers(slaveId: String) { + def makeOffers(workerId: String) { launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(slaveId, slaveHost(slaveId), freeCores(slaveId))))) + Seq(new WorkerOffer(workerId, slaveHost(workerId), freeCores(workerId))))) } // Launch tasks returned by a set of resource offers @@ -99,24 +99,24 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } // Remove a disconnected slave from the cluster - def removeSlave(slaveId: String, reason: String) { - logInfo("Slave " + slaveId + " disconnected, so removing it") - val numCores = freeCores(slaveId) - actorToSlaveId -= slaveActor(slaveId) - addressToSlaveId -= slaveAddress(slaveId) - slaveActor -= slaveId - slaveHost -= slaveId - freeCores -= slaveId - slaveHost -= slaveId + def removeSlave(workerId: String, reason: String) { + logInfo("Slave " + workerId + " disconnected, so removing it") + val numCores = freeCores(workerId) + actorToSlaveId -= slaveActor(workerId) + addressToSlaveId -= slaveAddress(workerId) + slaveActor -= workerId + slaveHost -= workerId + freeCores -= workerId + slaveHost -= workerId totalCoreCount.addAndGet(-numCores) - scheduler.slaveLost(slaveId, SlaveLost(reason)) + scheduler.slaveLost(workerId, SlaveLost(reason)) } } - var masterActor: ActorRef = null + var driverActor: ActorRef = null val taskIdsOnSlave = new HashMap[String, HashSet[String]] - def start() { + override def start() { val properties = new ArrayBuffer[(String, String)] val iterator = System.getProperties.entrySet.iterator while (iterator.hasNext) { @@ -126,15 +126,15 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor properties += ((key, value)) } } - masterActor = actorSystem.actorOf( - Props(new MasterActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) + driverActor = actorSystem.actorOf( + Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) } - def stop() { + override def stop() { try { - if (masterActor != null) { + if (driverActor != null) { val timeout = 5.seconds - val future = masterActor.ask(StopMaster)(timeout) + val future = driverActor.ask(StopDriver)(timeout) Await.result(future, timeout) } } catch { @@ -143,11 +143,11 @@ 
class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } } - def reviveOffers() { - masterActor ! ReviveOffers + override def reviveOffers() { + driverActor ! ReviveOffers } - def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2) + override def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2) } private[spark] object StandaloneSchedulerBackend { diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index 014906b028..7bf56a05d6 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -104,11 +104,11 @@ private[spark] class CoarseMesosSchedulerBackend( def createCommand(offer: Offer, numCores: Int): CommandInfo = { val runScript = new File(sparkHome, "run").getCanonicalPath - val masterUrl = "akka://spark@%s:%s/user/%s".format( - System.getProperty("spark.master.host"), System.getProperty("spark.master.port"), + val driverUrl = "akka://spark@%s:%s/user/%s".format( + System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), StandaloneSchedulerBackend.ACTOR_NAME) val command = "\"%s\" spark.executor.StandaloneExecutorBackend %s %s %s %d".format( - runScript, masterUrl, offer.getSlaveId.getValue, offer.getHostname, numCores) + runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores) val environment = Environment.newBuilder() sc.executorEnvs.foreach { case (key, value) => environment.addVariables(Environment.Variable.newBuilder() diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index a3d8671834..9fd2b454a4 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -11,52 +11,51 @@ import akka.util.duration._ import spark.{Logging, SparkException, Utils} - private[spark] class BlockManagerMaster( val actorSystem: ActorSystem, - isMaster: Boolean, + isDriver: Boolean, isLocal: Boolean, - masterIp: String, - masterPort: Int) + driverIp: String, + driverPort: Int) extends Logging { val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "3").toInt val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt - val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager" + val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager" val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager" val DEFAULT_MANAGER_IP: String = Utils.localHostName() val timeout = 10.seconds - var masterActor: ActorRef = { - if (isMaster) { - val masterActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)), - name = MASTER_AKKA_ACTOR_NAME) + var driverActor: ActorRef = { + if (isDriver) { + val driverActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)), + name = DRIVER_AKKA_ACTOR_NAME) logInfo("Registered BlockManagerMaster Actor") - masterActor + driverActor } else { - val url = "akka://spark@%s:%s/user/%s".format(masterIp, masterPort, MASTER_AKKA_ACTOR_NAME) + val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, DRIVER_AKKA_ACTOR_NAME) logInfo("Connecting to BlockManagerMaster: " + url) actorSystem.actorFor(url) } } - /** Remove a dead host from the master actor. This is only called on the master side. */ + /** Remove a dead host from the driver actor. 
This is only called on the driver side. */ def notifyADeadHost(host: String) { tell(RemoveHost(host)) logInfo("Removed " + host + " successfully in notifyADeadHost") } /** - * Send the master actor a heart beat from the slave. Returns true if everything works out, - * false if the master does not know about the given block manager, which means the block + * Send the driver actor a heart beat from the slave. Returns true if everything works out, + * false if the driver does not know about the given block manager, which means the block * manager should re-register. */ def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = { - askMasterWithRetry[Boolean](HeartBeat(blockManagerId)) + askDriverWithReply[Boolean](HeartBeat(blockManagerId)) } - /** Register the BlockManager's id with the master. */ + /** Register the BlockManager's id with the driver. */ def registerBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { logInfo("Trying to register BlockManager") @@ -70,25 +69,25 @@ private[spark] class BlockManagerMaster( storageLevel: StorageLevel, memSize: Long, diskSize: Long): Boolean = { - val res = askMasterWithRetry[Boolean]( + val res = askDriverWithReply[Boolean]( UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) logInfo("Updated info of block " + blockId) res } - /** Get locations of the blockId from the master */ + /** Get locations of the blockId from the driver */ def getLocations(blockId: String): Seq[BlockManagerId] = { - askMasterWithRetry[Seq[BlockManagerId]](GetLocations(blockId)) + askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) } - /** Get locations of multiple blockIds from the master */ + /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { - askMasterWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } - /** Get ids of other nodes in the cluster from the master */ + /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { - val result = askMasterWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) + val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) if (result.length != numPeers) { throw new SparkException( "Error getting peers, only got " + result.size + " instead of " + numPeers) @@ -98,10 +97,10 @@ private[spark] class BlockManagerMaster( /** * Remove a block from the slaves that have it. This can only be used to remove - * blocks that the master knows about. + * blocks that the driver knows about. */ def removeBlock(blockId: String) { - askMasterWithRetry(RemoveBlock(blockId)) + askDriverWithReply(RemoveBlock(blockId)) } /** @@ -111,33 +110,33 @@ private[spark] class BlockManagerMaster( * amount of remaining memory. 
*/ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askMasterWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } - /** Stop the master actor, called only on the Spark master node */ + /** Stop the driver actor, called only on the Spark driver node */ def stop() { - if (masterActor != null) { + if (driverActor != null) { tell(StopBlockManagerMaster) - masterActor = null + driverActor = null logInfo("BlockManagerMaster stopped") } } /** Send a one-way message to the master actor, to which we expect it to reply with true. */ private def tell(message: Any) { - if (!askMasterWithRetry[Boolean](message)) { + if (!askDriverWithReply[Boolean](message)) { throw new SparkException("BlockManagerMasterActor returned false, expected true.") } } /** - * Send a message to the master actor and get its result within a default timeout, or + * Send a message to the driver actor and get its result within a default timeout, or * throw a SparkException if this fails. */ - private def askMasterWithRetry[T](message: Any): T = { + private def askDriverWithReply[T](message: Any): T = { // TODO: Consider removing multiple attempts - if (masterActor == null) { - throw new SparkException("Error sending message to BlockManager as masterActor is null " + + if (driverActor == null) { + throw new SparkException("Error sending message to BlockManager as driverActor is null " + "[message = " + message + "]") } var attempts = 0 @@ -145,7 +144,7 @@ private[spark] class BlockManagerMaster( while (attempts < AKKA_RETRY_ATTEMPS) { attempts += 1 try { - val future = masterActor.ask(message)(timeout) + val future = driverActor.ask(message)(timeout) val result = Await.result(future, timeout) if (result == null) { throw new Exception("BlockManagerMaster returned null") diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala index 689f07b969..0b8f6d4303 100644 --- a/core/src/main/scala/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/spark/storage/ThreadingTest.scala @@ -75,9 +75,9 @@ private[spark] object ThreadingTest { System.setProperty("spark.kryoserializer.buffer.mb", "1") val actorSystem = ActorSystem("test") val serializer = new KryoSerializer - val masterIp: String = System.getProperty("spark.master.host", "localhost") - val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt - val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, masterIp, masterPort) + val driverIp: String = System.getProperty("spark.driver.host", "localhost") + val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt + val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, driverIp, driverPort) val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer, 1024 * 1024) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 01351de4ae..42ce6f3c74 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -46,7 +46,7 @@ public class JavaAPISuite implements Serializable { sc.stop(); sc = null; // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port"); + 
System.clearProperty("spark.driver.port"); } static class ReverseIntComparator implements Comparator, Serializable { diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala index b5e31ddae3..ff00dd05dd 100644 --- a/core/src/test/scala/spark/LocalSparkContext.scala +++ b/core/src/test/scala/spark/LocalSparkContext.scala @@ -26,7 +26,7 @@ object LocalSparkContext { def stop(sc: SparkContext) { sc.stop() // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 7d5305f1e0..718107d2b5 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -79,7 +79,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("remote fetch") { val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) - System.setProperty("spark.master.port", boundPort.toString) + System.setProperty("spark.driver.port", boundPort.toString) val masterTracker = new MapOutputTracker(actorSystem, true) val slaveTracker = new MapOutputTracker(actorSystem, false) masterTracker.registerShuffle(10, 1) diff --git a/docs/configuration.md b/docs/configuration.md index 036a0df480..a7054b4321 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -202,7 +202,7 @@ Apart from these, the following properties are also available, and may be useful 10 Maximum message size to allow in "control plane" communication (for serialized tasks and task - results), in MB. Increase this if your tasks need to send back large results to the master + results), in MB. Increase this if your tasks need to send back large results to the driver (e.g. using collect() on a large dataset). @@ -211,7 +211,7 @@ Apart from these, the following properties are also available, and may be useful 4 Number of actor threads to use for communication. Can be useful to increase on large clusters - when the master has a lot of CPU cores. + when the driver has a lot of CPU cores. @@ -222,17 +222,17 @@ Apart from these, the following properties are also available, and may be useful - spark.master.host + spark.driver.host (local hostname) - Hostname or IP address for the master to listen on. + Hostname or IP address for the driver to listen on. - spark.master.port + spark.driver.port (random) - Port for the master to listen on. + Port for the driver to listen on. 
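The two renamed properties documented above are plain system properties read by SparkContext at startup; a minimal sketch of setting them explicitly before creating a context, assuming the Spark API of this era (the object name, host, and port values are placeholders, and omitting the setProperty calls lets Spark pick the local address and a random port):

    import spark.SparkContext

    object DriverConfigExample {
      def main(args: Array[String]) {
        // Placeholder values; drop these two lines to use the defaults.
        System.setProperty("spark.driver.host", "127.0.0.1")
        System.setProperty("spark.driver.port", "7077")

        val sc = new SparkContext("local", "driver-config-example")
        println(sc.parallelize(1 to 10).count())
        sc.stop()
      }
    }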
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 46ab34f063..df7235756d 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -26,7 +26,7 @@ class PySparkTestCase(unittest.TestCase): sys.path = self._old_sys_path # To avoid Akka rebinding to the same port, since it doesn't unbind # immediately on shutdown - self.sc.jvm.System.clearProperty("spark.master.port") + self.sc.jvm.System.clearProperty("spark.driver.port") class TestCheckpoint(PySparkTestCase): diff --git a/repl/src/test/scala/spark/repl/ReplSuite.scala b/repl/src/test/scala/spark/repl/ReplSuite.scala index db78d06d4f..43559b96d3 100644 --- a/repl/src/test/scala/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/spark/repl/ReplSuite.scala @@ -31,7 +31,7 @@ class ReplSuite extends FunSuite { if (interp.sparkContext != null) interp.sparkContext.stop() // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") return out.toString } diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala index aa6be95f30..8c322dd698 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala @@ -153,8 +153,8 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log /** A helper actor that communicates with the NetworkInputTracker */ private class NetworkReceiverActor extends Actor { logInfo("Attempting to register with tracker") - val ip = System.getProperty("spark.master.host", "localhost") - val port = System.getProperty("spark.master.port", "7077").toInt + val ip = System.getProperty("spark.driver.host", "localhost") + val port = System.getProperty("spark.driver.port", "7077").toInt val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port) val tracker = env.actorSystem.actorFor(url) val timeout = 5.seconds diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index c84e7331c7..79d6093429 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -43,7 +43,7 @@ public class JavaAPISuite implements Serializable { ssc = null; // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port"); + System.clearProperty("spark.driver.port"); } @Test diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index bfdf32c73e..4a036f0710 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -10,7 +10,7 @@ class BasicOperationsSuite extends TestSuiteBase { after { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } test("map") { diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index d2f32c189b..563a7d1458 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ 
b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -19,7 +19,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { FileUtils.deleteDirectory(new File(checkpointDir)) // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } var ssc: StreamingContext = null diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala index 7493ac1207..c4cfffbfc1 100644 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -24,7 +24,7 @@ class FailureSuite extends TestSuiteBase with BeforeAndAfter { FileUtils.deleteDirectory(new File(checkpointDir)) // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } override def framework = "CheckpointSuite" diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index d7ba7a5d17..70ae6e3934 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -42,7 +42,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } test("network input stream") { diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index 0c6e928835..cd9608df53 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -13,7 +13,7 @@ class WindowOperationsSuite extends TestSuiteBase { after { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } val largerSlideInput = Seq( -- cgit v1.2.3 From 539491bbc333834b9ae2721ae6cf3524cefb91ea Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 25 Jan 2013 09:29:59 -0800 Subject: code reformatting --- core/src/main/scala/spark/RDD.scala | 4 ++-- core/src/main/scala/spark/storage/BlockManagerUI.scala | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 870cc5ca78..4fcab9279a 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -94,7 +94,7 @@ abstract class RDD[T: ClassManifest]( /** How this RDD depends on any parent RDDs. */ protected def getDependencies(): List[Dependency[_]] = dependencies_ - // A friendly name for this RDD + /** A friendly name for this RDD */ var name: String = null /** Optionally overridden by subclasses to specify placement preferences. */ @@ -111,7 +111,7 @@ abstract class RDD[T: ClassManifest]( /** A unique ID for this RDD (within its SparkContext). 
*/ val id = sc.newRddId() - /* Assign a name to this RDD */ + /** Assign a name to this RDD */ def setName(_name: String) = { name = _name this diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 35cbd59280..1003cc7a61 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -57,7 +57,8 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) - spark.storage.html.index.render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) + spark.storage.html.index. + render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) } }}} ~ get { path("rdd") { parameter("id") { id => { completeWith { @@ -67,9 +68,10 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray - val filteredStorageStatusList = StorageUtils.filterStorageStatusByPrefix(storageStatusList, prefix) + val filteredStorageStatusList = StorageUtils. + filterStorageStatusByPrefix(storageStatusList, prefix) - val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).first + val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) -- cgit v1.2.3 From 1cadaa164e9f078e4ca483edb9db7fd5507c9e64 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 25 Jan 2013 09:30:21 -0800 Subject: switch to TimeStampedHashMap for storing persistent Rdds --- core/src/main/scala/spark/SparkContext.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d994648899..10ceeb3028 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -44,6 +44,7 @@ import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} +import util.TimeStampedHashMap /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -110,7 +111,7 @@ class SparkContext( private[spark] val addedJars = HashMap[String, Long]() // Keeps track of all persisted RDDs - private[spark] val persistentRdds = new ConcurrentHashMap[Int, RDD[_]]() + private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]() // Add each JAR given through the constructor jars.foreach { addJar(_) } -- cgit v1.2.3 From a1d9d1767d821c1e25e485e32d9356b12aba6a01 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 25 Jan 2013 10:05:26 -0800 Subject: fixup 1cadaa1, changed api of map --- core/src/main/scala/spark/storage/StorageUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index 63ad5c125b..a10e3a95c6 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -56,8 +56,8 @@ object StorageUtils { // Find the id of the RDD, e.g. 
rdd_1 => 1 val rddId = rddKey.split("_").last.toInt // Get the friendly name for the rdd, if available. - val rddName = Option(sc.persistentRdds.get(rddId).name).getOrElse(rddKey) - val rddStorageLevel = sc.persistentRdds.get(rddId).getStorageLevel + val rddName = Option(sc.persistentRdds(rddId).name).getOrElse(rddKey) + val rddStorageLevel = sc.persistentRdds(rddId).getStorageLevel RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize) }.toArray -- cgit v1.2.3 From 8efbda0b179e3821a1221c6d78681fc74248cdac Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 25 Jan 2013 14:55:33 -0600 Subject: Call executeOnCompleteCallbacks in more finally blocks. --- .../main/scala/spark/scheduler/DAGScheduler.scala | 13 +++--- .../scala/spark/scheduler/ShuffleMapTask.scala | 46 +++++++++++----------- 2 files changed, 30 insertions(+), 29 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index b320be8863..f599eb00bd 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -40,7 +40,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with eventQueue.put(HostLost(host)) } - // Called by TaskScheduler to cancel an entier TaskSet due to repeated failures. + // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. override def taskSetFailed(taskSet: TaskSet, reason: String) { eventQueue.put(TaskSetFailed(taskSet, reason)) } @@ -54,8 +54,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // resubmit failed stages val POLL_TIMEOUT = 10L - private val lock = new Object // Used for access to the entire DAGScheduler - private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent] val nextRunId = new AtomicInteger(0) @@ -337,9 +335,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val rdd = job.finalStage.rdd val split = rdd.splits(job.partitions(0)) val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0) - val result = job.func(taskContext, rdd.iterator(split, taskContext)) - taskContext.executeOnCompleteCallbacks() - job.listener.taskSucceeded(0, result) + try { + val result = job.func(taskContext, rdd.iterator(split, taskContext)) + job.listener.taskSucceeded(0, result) + } finally { + taskContext.executeOnCompleteCallbacks() + } } catch { case e: Exception => job.listener.jobFailed(e) diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 19f5328eee..83641a2a84 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -81,7 +81,7 @@ private[spark] class ShuffleMapTask( with Externalizable with Logging { - def this() = this(0, null, null, 0, null) + protected def this() = this(0, null, null, 0, null) var split = if (rdd == null) { null @@ -117,34 +117,34 @@ private[spark] class ShuffleMapTask( override def run(attemptId: Long): MapStatus = { val numOutputSplits = dep.partitioner.numPartitions - val partitioner = dep.partitioner val taskContext = new TaskContext(stageId, partition, attemptId) + try { + // Partition the map output. 
+ val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)]) + for (elem <- rdd.iterator(split, taskContext)) { + val pair = elem.asInstanceOf[(Any, Any)] + val bucketId = dep.partitioner.getPartition(pair._1) + buckets(bucketId) += pair + } + val bucketIterators = buckets.map(_.iterator) - // Partition the map output. - val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)]) - for (elem <- rdd.iterator(split, taskContext)) { - val pair = elem.asInstanceOf[(Any, Any)] - val bucketId = partitioner.getPartition(pair._1) - buckets(bucketId) += pair - } - val bucketIterators = buckets.map(_.iterator) + val compressedSizes = new Array[Byte](numOutputSplits) - val compressedSizes = new Array[Byte](numOutputSplits) + val blockManager = SparkEnv.get.blockManager + for (i <- 0 until numOutputSplits) { + val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i + // Get a Scala iterator from Java map + val iter: Iterator[(Any, Any)] = bucketIterators(i) + val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false) + compressedSizes(i) = MapOutputTracker.compressSize(size) + } - val blockManager = SparkEnv.get.blockManager - for (i <- 0 until numOutputSplits) { - val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i - // Get a Scala iterator from Java map - val iter: Iterator[(Any, Any)] = bucketIterators(i) - val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false) - compressedSizes(i) = MapOutputTracker.compressSize(size) + return new MapStatus(blockManager.blockManagerId, compressedSizes) + } finally { + // Execute the callbacks on task completion. + taskContext.executeOnCompleteCallbacks() } - - // Execute the callbacks on task completion. - taskContext.executeOnCompleteCallbacks() - - return new MapStatus(blockManager.blockManagerId, compressedSizes) } override def preferredLocations: Seq[String] = locs -- cgit v1.2.3 From 49c05608f5f27354da120e2367b6d4a63ec38948 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 25 Jan 2013 17:04:16 -0800 Subject: add metadatacleaner for persisentRdd map --- core/src/main/scala/spark/SparkContext.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 10ceeb3028..bff54dbdd1 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -44,7 +44,7 @@ import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import util.TimeStampedHashMap +import util.{MetadataCleaner, TimeStampedHashMap} /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -113,6 +113,9 @@ class SparkContext( // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]() + private[spark] val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) + + // Add each JAR given through the constructor jars.foreach { addJar(_) } @@ -512,6 +515,7 @@ class SparkContext( /** Shut down the SparkContext. 
*/ def stop() { if (dagScheduler != null) { + metadataCleaner.cancel() dagScheduler.stop() dagScheduler = null taskScheduler = null @@ -654,6 +658,12 @@ class SparkContext( /** Register a new RDD, returning its RDD ID */ private[spark] def newRddId(): Int = nextRddId.getAndIncrement() + + private[spark] def cleanup(cleanupTime: Long) { + var sizeBefore = persistentRdds.size + persistentRdds.clearOldValues(cleanupTime) + logInfo("idToStage " + sizeBefore + " --> " + persistentRdds.size) + } } /** -- cgit v1.2.3 From d49cf0e587b7cbbd31917d9bb69f98466feb0f9f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 26 Jan 2013 15:57:01 -0800 Subject: Fix JavaRDDLike.flatMap(PairFlatMapFunction) (SPARK-668). This workaround is easier than rewriting JavaRDDLike in Java. --- .../main/scala/spark/api/java/JavaRDDLike.scala | 7 +++--- .../spark/api/java/PairFlatMapWorkaround.java | 20 ++++++++++++++++ core/src/test/scala/spark/JavaAPISuite.java | 28 ++++++++++++++++++++++ 3 files changed, 51 insertions(+), 4 deletions(-) create mode 100644 core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java (limited to 'core') diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index b3698ffa44..4c95c989b5 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -12,7 +12,7 @@ import spark.storage.StorageLevel import com.google.common.base.Optional -trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { +trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround[T] { def wrapRDD(rdd: RDD[T]): This implicit val classManifest: ClassManifest[T] @@ -82,10 +82,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Return a new RDD by first applying a function to all elements of this - * RDD, and then flattening the results. + * Part of the workaround for SPARK-668; called in PairFlatMapWorkaround.java. */ - def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = { + private[spark] def doFlatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = { import scala.collection.JavaConverters._ def fn = (x: T) => f.apply(x).asScala def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] diff --git a/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java new file mode 100644 index 0000000000..68b6fd6622 --- /dev/null +++ b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java @@ -0,0 +1,20 @@ +package spark.api.java; + +import spark.api.java.JavaPairRDD; +import spark.api.java.JavaRDDLike; +import spark.api.java.function.PairFlatMapFunction; + +import java.io.Serializable; + +/** + * Workaround for SPARK-668. + */ +class PairFlatMapWorkaround implements Serializable { + /** + * Return a new RDD by first applying a function to all elements of this + * RDD, and then flattening the results. 
+ */ + public JavaPairRDD flatMap(PairFlatMapFunction f) { + return ((JavaRDDLike ) this).doFlatMap(f); + } +} diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 01351de4ae..f50ba093e9 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -355,6 +355,34 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(11, pairs.count()); } + @Test + public void mapsFromPairsToPairs() { + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD pairRDD = sc.parallelizePairs(pairs); + + // Regression test for SPARK-668: + JavaPairRDD swapped = pairRDD.flatMap( + new PairFlatMapFunction, String, Integer>() { + @Override + public Iterable> call(Tuple2 item) throws Exception { + return Collections.singletonList(item.swap()); + } + }); + swapped.collect(); + + // There was never a bug here, but it's worth testing: + pairRDD.map(new PairFunction, String, Integer>() { + @Override + public Tuple2 call(Tuple2 item) throws Exception { + return item.swap(); + } + }).collect(); + } + @Test public void mapPartitions() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); -- cgit v1.2.3 From ad4232b4dadc6290d3c4696d3cc007d3f01cb236 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Sat, 26 Jan 2013 18:07:14 -0800 Subject: Fix deadlock in BlockManager reregistration triggered by failed updates. --- .../main/scala/spark/storage/BlockManager.scala | 35 +++++++++++++++++-- .../scala/spark/storage/BlockManagerSuite.scala | 40 +++++++++++++++++++++- 2 files changed, 72 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 19cdaaa984..19d35b8667 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -90,7 +90,10 @@ class BlockManager( val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) - @volatile private var shuttingDown = false + // Pending reregistration action being executed asynchronously or null if none + // is pending. Accesses should synchronize on asyncReregisterLock. + var asyncReregisterTask: Future[Unit] = null + val asyncReregisterLock = new Object private def heartBeat() { if (!master.sendHeartBeat(blockManagerId)) { @@ -147,6 +150,8 @@ class BlockManager( /** * Reregister with the master and report all blocks to it. This will be called by the heart beat * thread if our heartbeat to the block amnager indicates that we were not registered. + * + * Note that this method must be called without any BlockInfo locks held. */ def reregister() { // TODO: We might need to rate limit reregistering. @@ -155,6 +160,32 @@ class BlockManager( reportAllBlocks() } + /** + * Reregister with the master sometime soon. + */ + def asyncReregister() { + asyncReregisterLock.synchronized { + if (asyncReregisterTask == null) { + asyncReregisterTask = Future[Unit] { + reregister() + asyncReregisterLock.synchronized { + asyncReregisterTask = null + } + } + } + } + } + + /** + * For testing. Wait for any pending asynchronous reregistration; otherwise, do nothing. + */ + def waitForAsyncReregister() { + val task = asyncReregisterTask + if (task != null) { + Await.ready(task, Duration.Inf) + } + } + /** * Get storage level of local block. 
If no info exists for the block, then returns null. */ @@ -170,7 +201,7 @@ class BlockManager( if (needReregister) { logInfo("Got told to reregister updating block " + blockId) // Reregistering will report our new block for free. - reregister() + asyncReregister() } logDebug("Told master about block " + blockId) } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index a1aeb12f25..2165744689 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -219,18 +219,56 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val a2 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - assert(master.getLocations("a1").size > 0, "master was not told about a1") master.notifyADeadHost(store.blockManagerId.ip) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY) + store.waitForAsyncReregister() assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") assert(master.getLocations("a2").size > 0, "master was not told about a2") } + test("reregistration doesn't dead lock") { + val heartBeat = PrivateMethod[Unit]('heartBeat) + store = new BlockManager(actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = List(new Array[Byte](400)) + + // try many times to trigger any deadlocks + for (i <- 1 to 100) { + master.notifyADeadHost(store.blockManagerId.ip) + val t1 = new Thread { + override def run = { + store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, true) + } + } + val t2 = new Thread { + override def run = { + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + } + } + val t3 = new Thread { + override def run = { + store invokePrivate heartBeat() + } + } + + t1.start + t2.start + t3.start + t1.join + t2.join + t3.join + + store.dropFromMemory("a1", null) + store.dropFromMemory("a2", null) + store.waitForAsyncReregister() + } + } + test("in-memory LRU storage") { store = new BlockManager(actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) -- cgit v1.2.3 From 58fc6b2bed9f660fbf134aab188827b7d8975a62 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Sat, 26 Jan 2013 18:07:53 -0800 Subject: Handle duplicate registrations better. --- core/src/main/scala/spark/storage/BlockManagerMasterActor.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index f4d026da33..2216c33b76 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -183,7 +183,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { if (blockManagerId.ip == Utils.localHostName() && !isLocal) { logInfo("Got Register Msg from master node, don't register it") - } else { + } else if (!blockManagerInfo.contains(blockManagerId)) { blockManagerIdByHost.get(blockManagerId.ip) match { case Some(managers) => // A block manager of the same host name already exists. 
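
The one-line change above makes registration idempotent: a register message for a BlockManager ID that is already present in blockManagerInfo is ignored rather than recorded twice. A minimal stand-alone sketch of that guard, with purely illustrative names rather than the actual Spark API:

    // Sketch of an idempotent register: mutate state only when the id is new,
    // and report whether anything changed. Hypothetical names throughout.
    import scala.collection.mutable.HashMap

    class RegistrationSketch[Id, Info] {
      private val known = new HashMap[Id, Info]

      /** Returns true only for the first registration of a given id. */
      def register(id: Id, info: Info): Boolean = {
        if (known.contains(id)) {
          false              // duplicate registration: keep the existing record
        } else {
          known(id) = info   // first registration: record it
          true
        }
      }
    }
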
-- cgit v1.2.3 From 44b4a0f88fcb31727347b755ae8ec14d69571b52 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 27 Jan 2013 19:23:49 -0800 Subject: Track workers by executor ID instead of hostname to allow multiple executors per machine and remove the need for multiple IP addresses in unit tests. --- core/src/main/scala/spark/MapOutputTracker.scala | 4 +- core/src/main/scala/spark/SparkContext.scala | 6 +- core/src/main/scala/spark/SparkEnv.scala | 9 +- .../scala/spark/deploy/LocalSparkCluster.scala | 16 +-- .../main/scala/spark/deploy/master/Master.scala | 4 +- .../scala/spark/deploy/worker/ExecutorRunner.scala | 2 +- core/src/main/scala/spark/executor/Executor.scala | 4 +- .../spark/executor/MesosExecutorBackend.scala | 3 +- .../spark/executor/StandaloneExecutorBackend.scala | 14 +-- .../main/scala/spark/scheduler/DAGScheduler.scala | 44 ++++----- .../scala/spark/scheduler/DAGSchedulerEvent.scala | 2 +- .../src/main/scala/spark/scheduler/MapStatus.scala | 6 +- core/src/main/scala/spark/scheduler/Stage.scala | 11 ++- .../spark/scheduler/TaskSchedulerListener.scala | 2 +- .../spark/scheduler/cluster/ClusterScheduler.scala | 110 +++++++++++---------- .../cluster/SparkDeploySchedulerBackend.scala | 4 +- .../cluster/StandaloneSchedulerBackend.scala | 64 ++++++------ .../spark/scheduler/cluster/TaskDescription.scala | 2 +- .../scala/spark/scheduler/cluster/TaskInfo.scala | 7 +- .../spark/scheduler/cluster/TaskSetManager.scala | 38 +++---- .../spark/scheduler/cluster/WorkerOffer.scala | 4 +- .../scheduler/mesos/MesosSchedulerBackend.scala | 2 +- .../main/scala/spark/storage/BlockManager.scala | 10 +- .../main/scala/spark/storage/BlockManagerId.scala | 27 +++-- .../scala/spark/storage/BlockManagerMaster.scala | 12 +-- .../spark/storage/BlockManagerMasterActor.scala | 66 +++++-------- .../scala/spark/storage/BlockManagerMessages.scala | 2 +- .../main/scala/spark/storage/BlockManagerUI.scala | 7 +- .../main/scala/spark/storage/ThreadingTest.scala | 3 +- core/src/main/scala/spark/util/AkkaUtils.scala | 6 +- .../main/scala/spark/util/TimeStampedHashMap.scala | 4 +- core/src/test/scala/spark/DriverSuite.scala | 5 +- .../test/scala/spark/MapOutputTrackerSuite.scala | 69 +++++++------ .../scala/spark/storage/BlockManagerSuite.scala | 86 ++++++++-------- sbt/sbt | 2 +- 35 files changed, 343 insertions(+), 314 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index ac02f3363a..c1f012b419 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -114,7 +114,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea var array = mapStatuses(shuffleId) if (array != null) { array.synchronized { - if (array(mapId) != null && array(mapId).address == bmAddress) { + if (array(mapId) != null && array(mapId).location == bmAddress) { array(mapId) = null } } @@ -277,7 +277,7 @@ private[spark] object MapOutputTracker { throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing an output location for shuffle " + shuffleId)) } else { - (status.address, decompressSize(status.compressedSizes(reduceId))) + (status.location, decompressSize(status.compressedSizes(reduceId))) } } } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 4581c0adcf..39721b47ae 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ 
b/core/src/main/scala/spark/SparkContext.scala @@ -80,6 +80,7 @@ class SparkContext( // Create the Spark execution environment (cache, map output tracker, etc) private[spark] val env = SparkEnv.createFromSystemProperties( + "", System.getProperty("spark.master.host"), System.getProperty("spark.master.port").toInt, true, @@ -97,7 +98,7 @@ class SparkContext( // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]() - private[spark] val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) + private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup) // Add each JAR given through the constructor @@ -649,10 +650,9 @@ class SparkContext( /** Register a new RDD, returning its RDD ID */ private[spark] def newRddId(): Int = nextRddId.getAndIncrement() + /** Called by MetadataCleaner to clean up the persistentRdds map periodically */ private[spark] def cleanup(cleanupTime: Long) { - var sizeBefore = persistentRdds.size persistentRdds.clearOldValues(cleanupTime) - logInfo("idToStage " + sizeBefore + " --> " + persistentRdds.size) } } diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 2a7a8af83d..0c094edcf3 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -19,6 +19,7 @@ import spark.util.AkkaUtils * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set. */ class SparkEnv ( + val executorId: String, val actorSystem: ActorSystem, val serializer: Serializer, val closureSerializer: Serializer, @@ -58,11 +59,12 @@ object SparkEnv extends Logging { } def createFromSystemProperties( + executorId: String, hostname: String, port: Int, isMaster: Boolean, - isLocal: Boolean - ) : SparkEnv = { + isLocal: Boolean): SparkEnv = { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port) // Bit of a hack: If this is the master and our port was 0 (meaning bind to any free port), @@ -86,7 +88,7 @@ object SparkEnv extends Logging { val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt val blockManagerMaster = new BlockManagerMaster( actorSystem, isMaster, isLocal, masterIp, masterPort) - val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer) + val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer) val connectionManager = blockManager.connectionManager @@ -122,6 +124,7 @@ object SparkEnv extends Logging { } new SparkEnv( + executorId, actorSystem, serializer, closureSerializer, diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala index 4211d80596..8f51051e39 100644 --- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala @@ -9,6 +9,12 @@ import spark.{Logging, Utils} import scala.collection.mutable.ArrayBuffer +/** + * Testing class that creates a Spark standalone process in-cluster (that is, running the + * spark.deploy.master.Master and spark.deploy.worker.Workers in the same JVMs). Executors launched + * by the Workers still run in separate JVMs. This can be used to test distributed operation and + * fault recovery without spinning up a lot of processes. 
+ */ private[spark] class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging { @@ -35,16 +41,12 @@ class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) /* Start the Slaves */ for (slaveNum <- 1 to numSlaves) { - /* We can pretend to test distributed stuff by giving the slaves distinct hostnames. - All of 127/8 should be a loopback, we use 127.100.*.* in hopes that it is - sufficiently distinctive. */ - val slaveIpAddress = "127.100.0." + (slaveNum % 256) val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("sparkWorker" + slaveNum, slaveIpAddress, 0) + AkkaUtils.createActorSystem("sparkWorker" + slaveNum, localIpAddress, 0) slaveActorSystems += actorSystem val actor = actorSystem.actorOf( - Props(new Worker(slaveIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)), - name = "Worker") + Props(new Worker(localIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)), + name = "Worker") slaveActors += actor } diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 2c2cd0231b..2e7e868579 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -97,10 +97,10 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor exec.worker.removeExecutor(exec) // Only retry certain number of times so we don't go into an infinite loop. - if (jobInfo.incrementRetryCount <= JobState.MAX_NUM_RETRY) { + if (jobInfo.incrementRetryCount < JobState.MAX_NUM_RETRY) { schedule() } else { - val e = new SparkException("Job %s wth ID %s failed %d times.".format( + val e = new SparkException("Job %s with ID %s failed %d times.".format( jobInfo.desc.name, jobInfo.id, jobInfo.retryCount)) logError(e.getMessage, e) throw e diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index 0d1fe2a6b4..af3acfecb6 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -67,7 +67,7 @@ private[spark] class ExecutorRunner( /** Replace variables such as {{SLAVEID}} and {{CORES}} in a command argument passed to us */ def substituteVariables(argument: String): String = argument match { - case "{{SLAVEID}}" => workerId + case "{{EXECUTOR_ID}}" => execId.toString case "{{HOSTNAME}}" => hostname case "{{CORES}}" => cores.toString case other => other diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 28d9d40d43..bd21ba719a 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -30,7 +30,7 @@ private[spark] class Executor extends Logging { initLogging() - def initialize(slaveHostname: String, properties: Seq[(String, String)]) { + def initialize(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) { // Make sure the local hostname we report matches the cluster scheduler's name for this host Utils.setCustomHostname(slaveHostname) @@ -64,7 +64,7 @@ private[spark] class Executor extends Logging { ) // Initialize Spark environment (using system properties read above) - env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false) + env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false) SparkEnv.set(env) // Start 
worker thread pool diff --git a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala index eeab3959c6..1ef88075ad 100644 --- a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala @@ -29,9 +29,10 @@ private[spark] class MesosExecutorBackend(executor: Executor) executorInfo: ExecutorInfo, frameworkInfo: FrameworkInfo, slaveInfo: SlaveInfo) { + logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue) this.driver = driver val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) - executor.initialize(slaveInfo.getHostname, properties) + executor.initialize(executorInfo.getExecutorId.getValue, slaveInfo.getHostname, properties) } override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index a29bf974d2..435ee5743e 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -17,7 +17,7 @@ import spark.scheduler.cluster.RegisterSlave private[spark] class StandaloneExecutorBackend( executor: Executor, masterUrl: String, - slaveId: String, + executorId: String, hostname: String, cores: Int) extends Actor @@ -30,7 +30,7 @@ private[spark] class StandaloneExecutorBackend( try { logInfo("Connecting to master: " + masterUrl) master = context.actorFor(masterUrl) - master ! RegisterSlave(slaveId, hostname, cores) + master ! RegisterSlave(executorId, hostname, cores) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(master) // Doesn't work with remote actors, but useful for testing } catch { @@ -43,7 +43,7 @@ private[spark] class StandaloneExecutorBackend( override def receive = { case RegisteredSlave(sparkProperties) => logInfo("Successfully registered with master") - executor.initialize(hostname, sparkProperties) + executor.initialize(executorId, hostname, sparkProperties) case RegisterSlaveFailed(message) => logError("Slave registration failed: " + message) @@ -55,24 +55,24 @@ private[spark] class StandaloneExecutorBackend( } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - master ! StatusUpdate(slaveId, taskId, state, data) + master ! 
StatusUpdate(executorId, taskId, state, data) } } private[spark] object StandaloneExecutorBackend { - def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) { + def run(masterUrl: String, executorId: String, hostname: String, cores: Int) { // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) val actor = actorSystem.actorOf( - Props(new StandaloneExecutorBackend(new Executor, masterUrl, slaveId, hostname, cores)), + Props(new StandaloneExecutorBackend(new Executor, masterUrl, executorId, hostname, cores)), name = "Executor") actorSystem.awaitTermination() } def main(args: Array[String]) { if (args.length != 4) { - System.err.println("Usage: StandaloneExecutorBackend ") + System.err.println("Usage: StandaloneExecutorBackend ") System.exit(1) } run(args(0), args(1), args(2), args(3).toInt) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index f599eb00bd..bd541d4207 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -35,9 +35,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with eventQueue.put(CompletionEvent(task, reason, result, accumUpdates)) } - // Called by TaskScheduler when a host fails. - override def hostLost(host: String) { - eventQueue.put(HostLost(host)) + // Called by TaskScheduler when an executor fails. + override def executorLost(execId: String) { + eventQueue.put(ExecutorLost(execId)) } // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. @@ -72,7 +72,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // For tracking failed nodes, we use the MapOutputTracker's generation number, which is // sent with every task. When we detect a node failing, we note the current generation number - // and failed host, increment it for new tasks, and use this to ignore stray ShuffleMapTask + // and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask // results. // TODO: Garbage collect information about failure generations when we know there are no more // stray messages to detect. 
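
As a rough illustration of the generation-based bookkeeping these DAGScheduler hunks describe -- failures keyed by executor ID, with map outputs ignored if the task ran at or before the generation in which that executor was marked lost -- here is a minimal, self-contained sketch. Class and method names are illustrative only, not the real DAGScheduler API, and the generation counter is simplified relative to the MapOutputTracker-driven one in the patch.

    // Hypothetical sketch: an executor's results are stale if the task's
    // generation is no newer than the generation recorded when the executor
    // was marked lost.
    import scala.collection.mutable.HashMap

    class FailureGenerations(private var currentGeneration: Long = 0L) {
      private val failedGeneration = new HashMap[String, Long]  // execId -> generation

      /** Record an executor loss and bump the generation used for new tasks. */
      def markLost(execId: String) {
        if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) {
          failedGeneration(execId) = currentGeneration
          currentGeneration += 1
        }
      }

      /** True if a result from execId produced at taskGeneration should be ignored. */
      def isStale(execId: String, taskGeneration: Long): Boolean =
        failedGeneration.get(execId) match {
          case Some(failedAt) => taskGeneration <= failedAt
          case None => false
        }

      def generation: Long = currentGeneration
    }
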
@@ -108,7 +108,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } def clearCacheLocs() { - cacheLocs.clear + cacheLocs.clear() } /** @@ -271,8 +271,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with submitStage(finalStage) } - case HostLost(host) => - handleHostLost(host) + case ExecutorLost(execId) => + handleExecutorLost(execId) case completion: CompletionEvent => handleTaskCompletion(completion) @@ -436,10 +436,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with case smt: ShuffleMapTask => val stage = idToStage(smt.stageId) val status = event.result.asInstanceOf[MapStatus] - val host = status.address.ip - logInfo("ShuffleMapTask finished with host " + host) - if (failedGeneration.contains(host) && smt.generation <= failedGeneration(host)) { - logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + host) + val execId = status.location.executorId + logDebug("ShuffleMapTask finished on " + execId) + if (failedGeneration.contains(execId) && smt.generation <= failedGeneration(execId)) { + logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) } else { stage.addOutputLoc(smt.partition, status) } @@ -511,9 +511,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // Remember that a fetch failed now; this is used to resubmit the broken // stages later, after a small wait (to give other tasks the chance to fail) lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock - // TODO: mark the host as failed only if there were lots of fetch failures on it + // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleHostLost(bmAddress.ip, Some(task.generation)) + handleExecutorLost(bmAddress.executorId, Some(task.generation)) } case other => @@ -523,21 +523,21 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } /** - * Responds to a host being lost. This is called inside the event loop so it assumes that it can - * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside. + * Responds to an executor being lost. This is called inside the event loop, so it assumes it can + * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. * * Optionally the generation during which the failure was caught can be passed to avoid allowing * stray fetch failures from possibly retriggering the detection of a node as lost. 
*/ - def handleHostLost(host: String, maybeGeneration: Option[Long] = None) { + def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) { val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration) - if (!failedGeneration.contains(host) || failedGeneration(host) < currentGeneration) { - failedGeneration(host) = currentGeneration - logInfo("Host lost: " + host + " (generation " + currentGeneration + ")") - env.blockManager.master.notifyADeadHost(host) + if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) { + failedGeneration(execId) = currentGeneration + logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration)) + env.blockManager.master.removeExecutor(execId) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { - stage.removeOutputsOnHost(host) + stage.removeOutputsOnExecutor(execId) val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray mapOutputTracker.registerMapOutputs(shuffleId, locs, true) } @@ -546,7 +546,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } clearCacheLocs() } else { - logDebug("Additional host lost message for " + host + + logDebug("Additional executor lost message for " + execId + "(generation " + currentGeneration + ")") } } diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala index 3422a21d9d..b34fa78c07 100644 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala @@ -28,7 +28,7 @@ private[spark] case class CompletionEvent( accumUpdates: Map[Long, Any]) extends DAGSchedulerEvent -private[spark] case class HostLost(host: String) extends DAGSchedulerEvent +private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala index fae643f3a8..203abb917b 100644 --- a/core/src/main/scala/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/spark/scheduler/MapStatus.scala @@ -8,19 +8,19 @@ import java.io.{ObjectOutput, ObjectInput, Externalizable} * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. * The map output sizes are compressed using MapOutputTracker.compressSize. 
*/ -private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes: Array[Byte]) +private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte]) extends Externalizable { def this() = this(null, null) // For deserialization only def writeExternal(out: ObjectOutput) { - address.writeExternal(out) + location.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) } def readExternal(in: ObjectInput) { - address = BlockManagerId(in) + location = BlockManagerId(in) compressedSizes = new Array[Byte](in.readInt()) in.readFully(compressedSizes) } diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala index 4846b66729..e9419728e3 100644 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -51,18 +51,18 @@ private[spark] class Stage( def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.address == bmAddress) + val newList = prevList.filterNot(_.location == bmAddress) outputLocs(partition) = newList if (prevList != Nil && newList == Nil) { numAvailableOutputs -= 1 } } - def removeOutputsOnHost(host: String) { + def removeOutputsOnExecutor(execId: String) { var becameUnavailable = false for (partition <- 0 until numPartitions) { val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.address.ip == host) + val newList = prevList.filterNot(_.location.executorId == execId) outputLocs(partition) = newList if (prevList != Nil && newList == Nil) { becameUnavailable = true @@ -70,7 +70,8 @@ private[spark] class Stage( } } if (becameUnavailable) { - logInfo("%s is now unavailable on %s (%d/%d, %s)".format(this, host, numAvailableOutputs, numPartitions, isAvailable)) + logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( + this, execId, numAvailableOutputs, numPartitions, isAvailable)) } } @@ -82,7 +83,7 @@ private[spark] class Stage( def origin: String = rdd.origin - override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]" + override def toString = "Stage " + id override def hashCode(): Int = id } diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala index fa4de15d0d..9fcef86e46 100644 --- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala +++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala @@ -12,7 +12,7 @@ private[spark] trait TaskSchedulerListener { def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit // A node was lost from the cluster. - def hostLost(host: String): Unit + def executorLost(execId: String): Unit // The TaskScheduler wants to abort an entire task set. 
def taskSetFailed(taskSet: TaskSet, reason: String): Unit diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index a639b72795..0b4177805b 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -27,19 +27,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext) var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] val taskIdToTaskSetId = new HashMap[Long, String] - val taskIdToSlaveId = new HashMap[Long, String] + val taskIdToExecutorId = new HashMap[Long, String] val taskSetTaskIds = new HashMap[String, HashSet[Long]] // Incrementing Mesos task IDs val nextTaskId = new AtomicLong(0) - // Which hosts in the cluster are alive (contains hostnames) - val hostsAlive = new HashSet[String] + // Which executor IDs we have executors on + val activeExecutorIds = new HashSet[String] - // Which slave IDs we have executors on - val slaveIdsWithExecutors = new HashSet[String] + // The set of executors we have on each host; this is used to compute hostsAlive, which + // in turn is used to decide when we can attain data locality on a given host + val executorsByHost = new HashMap[String, HashSet[String]] - val slaveIdToHost = new HashMap[String, String] + val executorIdToHost = new HashMap[String, String] // JAR server, if any JARs were added by the user to the SparkContext var jarServer: HttpServer = null @@ -102,7 +103,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) activeTaskSets -= manager.taskSet.id activeTaskSetsQueue -= manager taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id) + taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) taskSetTaskIds.remove(manager.taskSet.id) } } @@ -117,8 +118,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname for (o <- offers) { - slaveIdToHost(o.slaveId) = o.hostname - hostsAlive += o.hostname + executorIdToHost(o.executorId) = o.hostname } // Build a list of tasks to assign to each slave val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) @@ -128,16 +128,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext) do { launchedTask = false for (i <- 0 until offers.size) { - val sid = offers(i).slaveId + val execId = offers(i).executorId val host = offers(i).hostname - manager.slaveOffer(sid, host, availableCpus(i)) match { + manager.slaveOffer(execId, host, availableCpus(i)) match { case Some(task) => tasks(i) += task val tid = task.taskId taskIdToTaskSetId(tid) = manager.taskSet.id taskSetTaskIds(manager.taskSet.id) += tid - taskIdToSlaveId(tid) = sid - slaveIdsWithExecutors += sid + taskIdToExecutorId(tid) = execId + activeExecutorIds += execId + if (!executorsByHost.contains(host)) { + executorsByHost(host) = new HashSet() + } + executorsByHost(host) += execId availableCpus(i) -= 1 launchedTask = true @@ -152,25 +156,21 @@ private[spark] class ClusterScheduler(val sc: SparkContext) def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { var taskSetToUpdate: Option[TaskSetManager] = None - var failedHost: Option[String] = None + var failedExecutor: Option[String] = None var taskFailed = false synchronized { try { - if (state == TaskState.LOST && taskIdToSlaveId.contains(tid)) { - // We lost the executor on this slave, so remember that it's 
gone - val slaveId = taskIdToSlaveId(tid) - val host = slaveIdToHost(slaveId) - if (hostsAlive.contains(host)) { - slaveIdsWithExecutors -= slaveId - hostsAlive -= host - activeTaskSetsQueue.foreach(_.hostLost(host)) - failedHost = Some(host) + if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { + // We lost this entire executor, so remember that it's gone + val execId = taskIdToExecutorId(tid) + if (activeExecutorIds.contains(execId)) { + removeExecutor(execId) + failedExecutor = Some(execId) } } taskIdToTaskSetId.get(tid) match { case Some(taskSetId) => if (activeTaskSets.contains(taskSetId)) { - //activeTaskSets(taskSetId).statusUpdate(status) taskSetToUpdate = Some(activeTaskSets(taskSetId)) } if (TaskState.isFinished(state)) { @@ -178,7 +178,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) if (taskSetTaskIds.contains(taskSetId)) { taskSetTaskIds(taskSetId) -= tid } - taskIdToSlaveId.remove(tid) + taskIdToExecutorId.remove(tid) } if (state == TaskState.FAILED) { taskFailed = true @@ -190,12 +190,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext) case e: Exception => logError("Exception in statusUpdate", e) } } - // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock + // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock if (taskSetToUpdate != None) { taskSetToUpdate.get.statusUpdate(tid, state, serializedData) } - if (failedHost != None) { - listener.hostLost(failedHost.get) + if (failedExecutor != None) { + listener.executorLost(failedExecutor.get) backend.reviveOffers() } if (taskFailed) { @@ -249,32 +249,42 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - def slaveLost(slaveId: String, reason: ExecutorLossReason) { - var failedHost: Option[String] = None + def executorLost(executorId: String, reason: ExecutorLossReason) { + var failedExecutor: Option[String] = None synchronized { - slaveIdToHost.get(slaveId) match { - case Some(host) => - if (hostsAlive.contains(host)) { - logError("Lost an executor on " + host + ": " + reason) - slaveIdsWithExecutors -= slaveId - hostsAlive -= host - activeTaskSetsQueue.foreach(_.hostLost(host)) - failedHost = Some(host) - } else { - // We may get multiple slaveLost() calls with different loss reasons. For example, one - // may be triggered by a dropped connection from the slave while another may be a report - // of executor termination from Mesos. We produce log messages for both so we eventually - // report the termination reason. - logError("Lost an executor on " + host + " (already removed): " + reason) - } - case None => - // We were told about a slave being lost before we could even allocate work to it - logError("Lost slave " + slaveId + " (no work assigned yet)") + if (activeExecutorIds.contains(executorId)) { + val host = executorIdToHost(executorId) + logError("Lost executor %s on %s: %s".format(executorId, host, reason)) + removeExecutor(executorId) + failedExecutor = Some(executorId) + } else { + // We may get multiple executorLost() calls with different loss reasons. For example, one + // may be triggered by a dropped connection from the slave while another may be a report + // of executor termination from Mesos. We produce log messages for both so we eventually + // report the termination reason. 
+ logError("Lost an executor " + executorId + " (already removed): " + reason) } } - if (failedHost != None) { - listener.hostLost(failedHost.get) + // Call listener.executorLost without holding the lock on this to prevent deadlock + if (failedExecutor != None) { + listener.executorLost(failedExecutor.get) backend.reviveOffers() } } + + /** Get a list of hosts that currently have executors */ + def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet + + /** Remove an executor from all our data structures and mark it as lost */ + private def removeExecutor(executorId: String) { + activeExecutorIds -= executorId + val host = executorIdToHost(executorId) + val execs = executorsByHost.getOrElse(host, new HashSet) + execs -= executorId + if (execs.isEmpty) { + executorsByHost -= host + } + executorIdToHost -= executorId + activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) + } } diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 4f82cd96dd..f0792c1b76 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -37,7 +37,7 @@ private[spark] class SparkDeploySchedulerBackend( val masterUrl = "akka://spark@%s:%s/user/%s".format( System.getProperty("spark.master.host"), System.getProperty("spark.master.port"), StandaloneSchedulerBackend.ACTOR_NAME) - val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") + val args = Seq(masterUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome) @@ -81,7 +81,7 @@ private[spark] class SparkDeploySchedulerBackend( executorIdToSlaveId.get(id) match { case Some(slaveId) => executorIdToSlaveId.remove(id) - scheduler.slaveLost(slaveId, reason) + scheduler.executorLost(slaveId, reason) case None => logInfo("No slave ID known for executor %s".format(id)) } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index eeaae23dc8..32be1e7a26 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -28,8 +28,8 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor val slaveAddress = new HashMap[String, Address] val slaveHost = new HashMap[String, String] val freeCores = new HashMap[String, Int] - val actorToSlaveId = new HashMap[ActorRef, String] - val addressToSlaveId = new HashMap[Address, String] + val actorToExecutorId = new HashMap[ActorRef, String] + val addressToExecutorId = new HashMap[Address, String] override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -37,28 +37,28 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } def receive = { - case RegisterSlave(slaveId, host, cores) => - if (slaveActor.contains(slaveId)) { - sender ! 
RegisterSlaveFailed("Duplicate slave ID: " + slaveId) + case RegisterSlave(executorId, host, cores) => + if (slaveActor.contains(executorId)) { + sender ! RegisterSlaveFailed("Duplicate executor ID: " + executorId) } else { - logInfo("Registered slave: " + sender + " with ID " + slaveId) + logInfo("Registered executor: " + sender + " with ID " + executorId) sender ! RegisteredSlave(sparkProperties) context.watch(sender) - slaveActor(slaveId) = sender - slaveHost(slaveId) = host - freeCores(slaveId) = cores - slaveAddress(slaveId) = sender.path.address - actorToSlaveId(sender) = slaveId - addressToSlaveId(sender.path.address) = slaveId + slaveActor(executorId) = sender + slaveHost(executorId) = host + freeCores(executorId) = cores + slaveAddress(executorId) = sender.path.address + actorToExecutorId(sender) = executorId + addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) makeOffers() } - case StatusUpdate(slaveId, taskId, state, data) => + case StatusUpdate(executorId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { - freeCores(slaveId) += 1 - makeOffers(slaveId) + freeCores(executorId) += 1 + makeOffers(executorId) } case ReviveOffers => @@ -69,13 +69,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor context.stop(self) case Terminated(actor) => - actorToSlaveId.get(actor).foreach(removeSlave(_, "Akka actor terminated")) + actorToExecutorId.get(actor).foreach(removeSlave(_, "Akka actor terminated")) case RemoteClientDisconnected(transport, address) => - addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client disconnected")) + addressToExecutorId.get(address).foreach(removeSlave(_, "remote Akka client disconnected")) case RemoteClientShutdown(transport, address) => - addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client shutdown")) + addressToExecutorId.get(address).foreach(removeSlave(_, "remote Akka client shutdown")) } // Make fake resource offers on all slaves @@ -85,31 +85,31 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } // Make fake resource offers on just one slave - def makeOffers(slaveId: String) { + def makeOffers(executorId: String) { launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(slaveId, slaveHost(slaveId), freeCores(slaveId))))) + Seq(new WorkerOffer(executorId, slaveHost(executorId), freeCores(executorId))))) } // Launch tasks returned by a set of resource offers def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { - freeCores(task.slaveId) -= 1 - slaveActor(task.slaveId) ! LaunchTask(task) + freeCores(task.executorId) -= 1 + slaveActor(task.executorId) ! 
LaunchTask(task) } } // Remove a disconnected slave from the cluster - def removeSlave(slaveId: String, reason: String) { - logInfo("Slave " + slaveId + " disconnected, so removing it") - val numCores = freeCores(slaveId) - actorToSlaveId -= slaveActor(slaveId) - addressToSlaveId -= slaveAddress(slaveId) - slaveActor -= slaveId - slaveHost -= slaveId - freeCores -= slaveId - slaveHost -= slaveId + def removeSlave(executorId: String, reason: String) { + logInfo("Slave " + executorId + " disconnected, so removing it") + val numCores = freeCores(executorId) + actorToExecutorId -= slaveActor(executorId) + addressToExecutorId -= slaveAddress(executorId) + slaveActor -= executorId + slaveHost -= executorId + freeCores -= executorId + slaveHost -= executorId totalCoreCount.addAndGet(-numCores) - scheduler.slaveLost(slaveId, SlaveLost(reason)) + scheduler.executorLost(executorId, SlaveLost(reason)) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala index aa097fd3a2..b41e951be9 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala @@ -5,7 +5,7 @@ import spark.util.SerializableBuffer private[spark] class TaskDescription( val taskId: Long, - val slaveId: String, + val executorId: String, val name: String, _serializedTask: ByteBuffer) extends Serializable { diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala index ca84503780..0f975ce1eb 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala @@ -4,7 +4,12 @@ package spark.scheduler.cluster * Information about a running task attempt inside a TaskSet. */ private[spark] -class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host: String) { +class TaskInfo( + val taskId: Long, + val index: Int, + val launchTime: Long, + val executorId: String, + val host: String) { var finishTime: Long = 0 var failed = false diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index a089b71644..26201ad0dd 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -138,10 +138,11 @@ private[spark] class TaskSetManager( // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the // task must have a preference for this host (or no preferred locations at all). 
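The sched.hostsAlive set consulted below is no longer stored separately; it is derived from the per-host executor sets, so a host drops out automatically once its last executor is removed. A small illustrative sketch of that bookkeeping (an ExecutorRegistry class invented here for illustration, not the real scheduler) is:

    import scala.collection.mutable.{HashMap, HashSet}

    class ExecutorRegistry {
      private val executorsByHost = new HashMap[String, HashSet[String]]
      private val executorIdToHost = new HashMap[String, String]

      def addExecutor(execId: String, host: String) {
        executorIdToHost(execId) = host
        executorsByHost.getOrElseUpdate(host, new HashSet[String]) += execId
      }

      def removeExecutor(execId: String) {
        for (host <- executorIdToHost.remove(execId)) {
          val execs = executorsByHost.getOrElse(host, new HashSet[String])
          execs -= execId
          if (execs.isEmpty) {
            executorsByHost -= host  // no executors left, so the host is no longer "alive"
          }
        }
      }

      // Locality decisions only need the hosts that still have at least one executor.
      def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet
    }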
def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { + val hostsAlive = sched.hostsAlive speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set val localTask = speculatableTasks.find { index => - val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive + val locations = tasks(index).preferredLocations.toSet & hostsAlive val attemptLocs = taskAttempts(index).map(_.host) (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host) } @@ -189,7 +190,7 @@ private[spark] class TaskSetManager( } // Respond to an offer of a single slave from the scheduler by finding a task - def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[TaskDescription] = { + def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { val time = System.currentTimeMillis val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) @@ -206,11 +207,11 @@ private[spark] class TaskSetManager( } else { "non-preferred, not one of " + task.preferredLocations.mkString(", ") } - logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( - taskSet.id, index, taskId, slaveId, host, prefStr)) + logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format( + taskSet.id, index, taskId, execId, host, prefStr)) // Do various bookkeeping copiesRunning(index) += 1 - val info = new TaskInfo(taskId, index, time, host) + val info = new TaskInfo(taskId, index, time, execId, host) taskInfos(taskId) = info taskAttempts(index) = info :: taskAttempts(index) if (preferred) { @@ -224,7 +225,7 @@ private[spark] class TaskSetManager( logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) val taskName = "task %s:%d".format(taskSet.id, index) - return Some(new TaskDescription(taskId, slaveId, taskName, serializedTask)) + return Some(new TaskDescription(taskId, execId, taskName, serializedTask)) } case _ => } @@ -356,19 +357,22 @@ private[spark] class TaskSetManager( sched.taskSetFinished(this) } - def hostLost(hostname: String) { - logInfo("Re-queueing tasks for " + hostname + " from TaskSet " + taskSet.id) - // If some task has preferred locations only on hostname, put it in the no-prefs list - // to avoid the wait from delay scheduling - for (index <- getPendingTasksForHost(hostname)) { - val newLocs = tasks(index).preferredLocations.toSet & sched.hostsAlive - if (newLocs.isEmpty) { - pendingTasksWithNoPrefs += index + def executorLost(execId: String, hostname: String) { + logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) + val newHostsAlive = sched.hostsAlive + // If some task has preferred locations only on hostname, and there are no more executors there, + // put it in the no-prefs list to avoid the wait from delay scheduling + if (!newHostsAlive.contains(hostname)) { + for (index <- getPendingTasksForHost(hostname)) { + val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive + if (newLocs.isEmpty) { + pendingTasksWithNoPrefs += index + } } } - // Re-enqueue any tasks that ran on the failed host if this is a shuffle map stage + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage if (tasks(0).isInstanceOf[ShuffleMapTask]) { - for ((tid, info) <- taskInfos if info.host == hostname) { + for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index if 
(finished(index)) { finished(index) = false @@ -382,7 +386,7 @@ private[spark] class TaskSetManager( } } // Also re-enqueue any tasks that were running on the node - for ((tid, info) <- taskInfos if info.running && info.host == hostname) { + for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { taskLost(tid, TaskState.KILLED, null) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala index 6b919d68b2..3c3afcbb14 100644 --- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala +++ b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala @@ -1,8 +1,8 @@ package spark.scheduler.cluster /** - * Represents free resources available on a worker node. + * Represents free resources available on an executor. */ private[spark] -class WorkerOffer(val slaveId: String, val hostname: String, val cores: Int) { +class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) { } diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index 2989e31f5e..f3467db86b 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -268,7 +268,7 @@ private[spark] class MesosSchedulerBackend( synchronized { slaveIdsWithExecutors -= slaveId.getValue } - scheduler.slaveLost(slaveId.getValue, reason) + scheduler.executorLost(slaveId.getValue, reason) } override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 19d35b8667..1215d5f5c8 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -30,6 +30,7 @@ extends Exception(message) private[spark] class BlockManager( + executorId: String, actorSystem: ActorSystem, val master: BlockManagerMaster, val serializer: Serializer, @@ -68,8 +69,8 @@ class BlockManager( val connectionManager = new ConnectionManager(0) implicit val futureExecContext = connectionManager.futureExecContext - val connectionManagerId = connectionManager.id - val blockManagerId = BlockManagerId(connectionManagerId.host, connectionManagerId.port) + val blockManagerId = BlockManagerId( + executorId, connectionManager.id.host, connectionManager.id.port) // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) @@ -109,8 +110,9 @@ class BlockManager( /** * Construct a BlockManager with a memory limit set based on system properties. 
*/ - def this(actorSystem: ActorSystem, master: BlockManagerMaster, serializer: Serializer) = { - this(actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties) + def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, + serializer: Serializer) = { + this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties) } /** diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index abb8b45a1f..f2f1e77d41 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -7,27 +7,32 @@ import java.util.concurrent.ConcurrentHashMap * This class represent an unique identifier for a BlockManager. * The first 2 constructors of this class is made private to ensure that * BlockManagerId objects can be created only using the factory method in - * [[spark.storage.BlockManager$]]. This allows de-duplication of id objects. + * [[spark.storage.BlockManager$]]. This allows de-duplication of ID objects. * Also, constructor parameters are private to ensure that parameters cannot * be modified from outside this class. */ private[spark] class BlockManagerId private ( + private var executorId_ : String, private var ip_ : String, private var port_ : Int ) extends Externalizable { - private def this() = this(null, 0) // For deserialization only + private def this() = this(null, null, 0) // For deserialization only - def ip = ip_ + def executorId: String = executorId_ - def port = port_ + def ip: String = ip_ + + def port: Int = port_ override def writeExternal(out: ObjectOutput) { + out.writeUTF(executorId_) out.writeUTF(ip_) out.writeInt(port_) } override def readExternal(in: ObjectInput) { + executorId_ = in.readUTF() ip_ = in.readUTF() port_ = in.readInt() } @@ -35,21 +40,23 @@ private[spark] class BlockManagerId private ( @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = "BlockManagerId(" + ip + ", " + port + ")" + override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port) - override def hashCode = ip.hashCode * 41 + port + override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port override def equals(that: Any) = that match { - case id: BlockManagerId => port == id.port && ip == id.ip - case _ => false + case id: BlockManagerId => + executorId == id.executorId && port == id.port && ip == id.ip + case _ => + false } } private[spark] object BlockManagerId { - def apply(ip: String, port: Int) = - getCachedBlockManagerId(new BlockManagerId(ip, port)) + def apply(execId: String, ip: String, port: Int) = + getCachedBlockManagerId(new BlockManagerId(execId, ip, port)) def apply(in: ObjectInput) = { val obj = new BlockManagerId() diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 937115e92c..55ff1dde9c 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -24,7 +24,7 @@ private[spark] class BlockManagerMaster( masterPort: Int) extends Logging { - val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "3").toInt + val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt 
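A point worth noting in the BlockManagerId change above: equality and hashing now incorporate the executor ID, and readResolve funnels every deserialized instance through a shared cache so duplicate IDs collapse to one canonical object. The de-duplication idea, reduced to a toy Id class (not the real BlockManagerId), is roughly:

    import java.util.concurrent.ConcurrentHashMap

    // Toy stand-in: identity is (executorId, ip, port); the case class supplies equals/hashCode.
    case class Id(executorId: String, ip: String, port: Int) {
      // Replace each freshly deserialized copy with the cached canonical instance.
      private def readResolve(): Object = Id.getCached(this)
    }

    object Id {
      private val cache = new ConcurrentHashMap[Id, Id]()

      def getCached(id: Id): Id = {
        cache.putIfAbsent(id, id)  // first writer wins
        cache.get(id)              // all callers see the same object afterwards
      }
    }

Instances built directly still have to be routed through the cache, which is what the companion apply method in the patch does.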
val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager" @@ -45,10 +45,10 @@ private[spark] class BlockManagerMaster( } } - /** Remove a dead host from the master actor. This is only called on the master side. */ - def notifyADeadHost(host: String) { - tell(RemoveHost(host)) - logInfo("Removed " + host + " successfully in notifyADeadHost") + /** Remove a dead executor from the master actor. This is only called on the master side. */ + def removeExecutor(execId: String) { + tell(RemoveExecutor(execId)) + logInfo("Removed " + execId + " successfully in removeExecutor") } /** @@ -146,7 +146,7 @@ private[spark] class BlockManagerMaster( } var attempts = 0 var lastException: Exception = null - while (attempts < AKKA_RETRY_ATTEMPS) { + while (attempts < AKKA_RETRY_ATTEMPTS) { attempts += 1 try { val future = masterActor.ask(message)(timeout) diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index b31b6286d3..f88517f1a3 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -23,9 +23,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo] - // Mapping from host name to block manager id. We allow multiple block managers - // on the same host name (ip). - private val blockManagerIdByHost = new HashMap[String, ArrayBuffer[BlockManagerId]] + // Mapping from executor ID to block manager ID. + private val blockManagerIdByExecutor = new HashMap[String, BlockManagerId] // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] @@ -74,8 +73,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { case RemoveBlock(blockId) => removeBlock(blockId) - case RemoveHost(host) => - removeHost(host) + case RemoveExecutor(execId) => + removeExecutor(execId) sender ! true case StopBlockManagerMaster => @@ -99,16 +98,12 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { def removeBlockManager(blockManagerId: BlockManagerId) { val info = blockManagerInfo(blockManagerId) - // Remove the block manager from blockManagerIdByHost. If the list of block - // managers belonging to the IP is empty, remove the entry from the hash map. - blockManagerIdByHost.get(blockManagerId.ip).foreach { managers: ArrayBuffer[BlockManagerId] => - managers -= blockManagerId - if (managers.size == 0) blockManagerIdByHost.remove(blockManagerId.ip) - } + // Remove the block manager from blockManagerIdByExecutor. + blockManagerIdByExecutor -= blockManagerId.executorId // Remove it from blockManagerInfo and remove all the blocks. 
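The ask-with-retry loop patched in BlockManagerMaster above (a bounded number of attempts, a fixed wait between them, and the last failure rethrown at the end) is a generic shape that can be sketched without any Akka machinery; the defaults below simply mirror the spark.akka.num.retries and spark.akka.retry.wait settings:

    object RetryUtil {
      // Run `body` up to `attempts` times, sleeping `intervalMs` between failures.
      def retry[T](attempts: Int = 3, intervalMs: Long = 3000)(body: => T): T = {
        var lastException: Exception = null
        var attempt = 0
        while (attempt < attempts) {
          attempt += 1
          try {
            return body
          } catch {
            case e: Exception =>
              lastException = e
              Thread.sleep(intervalMs)
          }
        }
        throw new RuntimeException("Giving up after " + attempts + " attempts", lastException)
      }
    }

A call site would look like RetryUtil.retry() { singleAttempt(msg) }, where singleAttempt is whatever one-shot operation needs retrying.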
blockManagerInfo.remove(blockManagerId) - var iterator = info.blocks.keySet.iterator + val iterator = info.blocks.keySet.iterator while (iterator.hasNext) { val blockId = iterator.next val locations = blockLocations.get(blockId)._2 @@ -133,17 +128,15 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { toRemove.foreach(removeBlockManager) } - def removeHost(host: String) { - logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.") - logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq) - blockManagerIdByHost.get(host).foreach(_.foreach(removeBlockManager)) - logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq) + def removeExecutor(execId: String) { + logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") + blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) sender ! true } def heartBeat(blockManagerId: BlockManagerId) { if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + if (blockManagerId.executorId == "" && !isLocal) { sender ! true } else { sender ! false @@ -188,24 +181,20 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { sender ! res } - private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockManagerId + " " - - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { - logInfo("Got Register Msg from master node, don't register it") - } else if (!blockManagerInfo.contains(blockManagerId)) { - blockManagerIdByHost.get(blockManagerId.ip) match { - case Some(managers) => - // A block manager of the same host name already exists. - logInfo("Got another registration for host " + blockManagerId) - managers += blockManagerId + private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + if (id.executorId == "" && !isLocal) { + // Got a register message from the master node; don't register it + } else if (!blockManagerInfo.contains(id)) { + blockManagerIdByExecutor.get(id.executorId) match { + case Some(manager) => + // A block manager of the same host name already exists + logError("Got two different block manager registrations on " + id.executorId) + System.exit(1) case None => - blockManagerIdByHost += (blockManagerId.ip -> ArrayBuffer(blockManagerId)) + blockManagerIdByExecutor(id.executorId) = id } - - blockManagerInfo += (blockManagerId -> new BlockManagerMasterActor.BlockManagerInfo( - blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor)) + blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo( + id, System.currentTimeMillis(), maxMemSize, slaveActor) } sender ! true } @@ -217,11 +206,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { memSize: Long, diskSize: Long) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockManagerId + " " + blockId + " " - if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + if (blockManagerId.executorId == "" && !isLocal) { // We intentionally do not register the master (except in local mode), // so we should not indicate failure. sender ! 
true @@ -353,8 +339,8 @@ object BlockManagerMasterActor { _lastSeenMs = System.currentTimeMillis() } - def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long) - : Unit = synchronized { + def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, + diskSize: Long) { updateLastSeenMs() diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index 3d03ff3a93..1494f90103 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -88,7 +88,7 @@ private[spark] case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster private[spark] -case class RemoveHost(host: String) extends ToBlockManagerMaster +case class RemoveExecutor(execId: String) extends ToBlockManagerMaster private[spark] case object StopBlockManagerMaster extends ToBlockManagerMaster diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 1003cc7a61..b7423c7234 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -11,6 +11,7 @@ import cc.spray.typeconversion.TwirlSupport._ import scala.collection.mutable.ArrayBuffer import spark.{Logging, SparkContext, SparkEnv} import spark.util.AkkaUtils +import spark.Utils private[spark] @@ -20,10 +21,10 @@ object BlockManagerUI extends Logging { def start(actorSystem : ActorSystem, masterActor: ActorRef, sc: SparkContext) { val webUIDirectives = new BlockManagerUIDirectives(actorSystem, masterActor, sc) try { - logInfo("Starting BlockManager WebUI.") - val port = Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt - AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, + val boundPort = AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", + Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt, webUIDirectives.handler, "BlockManagerHTTPServer") + logInfo("Started BlockManager web UI at %s:%d".format(Utils.localHostName(), boundPort)) } catch { case e: Exception => logError("Failed to create BlockManager WebUI", e) diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala index 689f07b969..f04c046c31 100644 --- a/core/src/main/scala/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/spark/storage/ThreadingTest.scala @@ -78,7 +78,8 @@ private[spark] object ThreadingTest { val masterIp: String = System.getProperty("spark.master.host", "localhost") val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, masterIp, masterPort) - val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer, 1024 * 1024) + val blockManager = new BlockManager( + "", actorSystem, blockManagerMaster, serializer, 1024 * 1024) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index ff2c3079be..775ff8f1aa 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -52,10 +52,10 @@ private[spark] object 
AkkaUtils { /** * Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to - * handle requests. Throws a SparkException if this fails. + * handle requests. Returns the bound port or throws a SparkException on failure. */ def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, - name: String = "HttpServer") { + name: String = "HttpServer"): Int = { val ioWorker = new IoWorker(actorSystem).start() val httpService = actorSystem.actorOf(Props(new HttpService(route))) val rootService = actorSystem.actorOf(Props(new SprayCanRootService(httpService))) @@ -67,7 +67,7 @@ private[spark] object AkkaUtils { try { Await.result(future, timeout) match { case bound: HttpServer.Bound => - return + return bound.endpoint.getPort case other: Any => throw new SparkException("Failed to bind web UI to port " + port + ": " + other) } diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index bb7c5c01c8..188f8910da 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -63,9 +63,9 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() - override def size(): Int = internalMap.size() + override def size: Int = internalMap.size - override def foreach[U](f: ((A, B)) => U): Unit = { + override def foreach[U](f: ((A, B)) => U) { val iterator = internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala index 70a7c8bc2f..342610e1dd 100644 --- a/core/src/test/scala/spark/DriverSuite.scala +++ b/core/src/test/scala/spark/DriverSuite.scala @@ -13,7 +13,8 @@ class DriverSuite extends FunSuite with Timeouts { val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) forAll(masters) { (master: String) => failAfter(10 seconds) { - Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master), new File(System.getenv("SPARK_HOME"))) + Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master), + new File(System.getenv("SPARK_HOME"))) } } } @@ -28,4 +29,4 @@ object DriverWithoutCleanup { val sc = new SparkContext(args(0), "DriverWithoutCleanup") sc.parallelize(1 to 100, 4).count() } -} \ No newline at end of file +} diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 7d5305f1e0..e8fe7ecabc 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -43,13 +43,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000))) val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === 
Seq((BlockManagerId("hostA", 1000), size1000), - (BlockManagerId("hostB", 1000), size10000))) + assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), + (BlockManagerId("b", "hostB", 1000), size10000))) tracker.stop() } @@ -61,47 +61,52 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), Array(compressedSize1000, compressedSize1000, compressedSize1000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000, compressedSize1000))) // As if we had two simulatenous fetch failures - tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) - tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) - // The remaining reduce task might try to grab the output dispite the shuffle failure; + // The remaining reduce task might try to grab the output despite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } } test("remote fetch") { - val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("test", "localhost", 0) - System.setProperty("spark.master.port", boundPort.toString) - val masterTracker = new MapOutputTracker(actorSystem, true) - val slaveTracker = new MapOutputTracker(actorSystem, false) - masterTracker.registerShuffle(10, 1) - masterTracker.incrementGeneration() - slaveTracker.updateGeneration(masterTracker.getGeneration) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + try { + System.clearProperty("spark.master.host") // In case some previous test had set it + val (actorSystem, boundPort) = + AkkaUtils.createActorSystem("test", "localhost", 0) + System.setProperty("spark.master.port", boundPort.toString) + val masterTracker = new MapOutputTracker(actorSystem, true) + val slaveTracker = new MapOutputTracker(actorSystem, false) + masterTracker.registerShuffle(10, 1) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("hostA", 1000), Array(compressedSize1000))) - masterTracker.incrementGeneration() - slaveTracker.updateGeneration(masterTracker.getGeneration) - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("hostA", 1000), size1000))) + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("a", "hostA", 1000), Array(compressedSize1000))) + masterTracker.incrementGeneration() + 
slaveTracker.updateGeneration(masterTracker.getGeneration) + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000), size1000))) - masterTracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) - masterTracker.incrementGeneration() - slaveTracker.updateGeneration(masterTracker.getGeneration) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - // failure should be cached - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + // failure should be cached + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + } finally { + System.clearProperty("spark.master.port") + } } } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 2165744689..2d177bbf67 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -86,9 +86,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("BlockManagerId object caching") { - val id1 = BlockManagerId("XXX", 1) - val id2 = BlockManagerId("XXX", 1) // this should return the same object as id1 - val id3 = BlockManagerId("XXX", 2) // this should return a different object + val id1 = BlockManagerId("e1", "XXX", 1) + val id2 = BlockManagerId("e1", "XXX", 1) // this should return the same object as id1 + val id3 = BlockManagerId("e1", "XXX", 2) // this should return a different object assert(id2 === id1, "id2 is not same as id1") assert(id2.eq(id1), "id2 is not the same object as id1") assert(id3 != id1, "id3 is same as id1") @@ -103,7 +103,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 1 manager interaction") { - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("", actorSystem, master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -133,8 +133,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 2 managers interaction") { - store = new BlockManager(actorSystem, master, serializer, 2000) - store2 = new BlockManager(actorSystem, master, new KryoSerializer, 2000) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000) + store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer, 2000) val peers = master.getPeers(store.blockManagerId, 1) assert(peers.size === 1, "master did not return the other manager as a peer") @@ -149,7 +149,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing block") { - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("", actorSystem, master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -198,7 +198,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("", actorSystem, master, 
serializer, 2000) val a1 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) @@ -206,7 +206,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.getSingle("a1") != None, "a1 was not in store") assert(master.getLocations("a1").size > 0, "master was not told about a1") - master.notifyADeadHost(store.blockManagerId.ip) + master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") store invokePrivate heartBeat() @@ -214,14 +214,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("reregistration on block update") { - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("", actorSystem, master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) assert(master.getLocations("a1").size > 0, "master was not told about a1") - master.notifyADeadHost(store.blockManagerId.ip) + master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY) @@ -233,35 +233,35 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration doesn't dead lock") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("", actorSystem, master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) // try many times to trigger any deadlocks for (i <- 1 to 100) { - master.notifyADeadHost(store.blockManagerId.ip) + master.removeExecutor(store.blockManagerId.executorId) val t1 = new Thread { - override def run = { + override def run() { store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, true) } } val t2 = new Thread { - override def run = { + override def run() { store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) } } val t3 = new Thread { - override def run = { + override def run() { store invokePrivate heartBeat() } } - t1.start - t2.start - t3.start - t1.join - t2.join - t3.join + t1.start() + t2.start() + t3.start() + t1.join() + t2.join() + t3.join() store.dropFromMemory("a1", null) store.dropFromMemory("a2", null) @@ -270,7 +270,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -289,7 +289,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage with serialization") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -308,14 +308,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of same RDD") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) store.putSingle("rdd_0_1", a1, 
StorageLevel.MEMORY_ONLY) store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY) store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY) - // Even though we accessed rdd_0_3 last, it should not have replaced partitiosn 1 and 2 + // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2 // from the same RDD assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") @@ -327,7 +327,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of multiple RDDs") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) @@ -350,7 +350,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("on-disk storage") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -363,7 +363,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -378,7 +378,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with getLocalBytes") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -393,7 +393,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -408,7 +408,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization and getLocalBytes") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -423,7 +423,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -448,7 +448,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU with streams") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val list1 = List(new 
Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -472,7 +472,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels and streams") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -518,7 +518,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("overly large block") { - store = new BlockManager(actorSystem, master, serializer, 500) + store = new BlockManager("", actorSystem, master, serializer, 500) store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.getSingle("a1") === None, "a1 was in store") store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) @@ -529,49 +529,49 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block compression") { try { System.setProperty("spark.shuffle.compress", "true") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000) store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed") store.stop() store = null System.setProperty("spark.shuffle.compress", "false") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec2", actorSystem, master, serializer, 2000) store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed") store.stop() store = null System.setProperty("spark.broadcast.compress", "true") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec3", actorSystem, master, serializer, 2000) store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed") store.stop() store = null System.setProperty("spark.broadcast.compress", "false") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec4", actorSystem, master, serializer, 2000) store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed") store.stop() store = null System.setProperty("spark.rdd.compress", "true") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec5", actorSystem, master, serializer, 2000) store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed") store.stop() store = null System.setProperty("spark.rdd.compress", "false") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec6", actorSystem, master, serializer, 2000) store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was 
compressed") store.stop() store = null // Check that any other block types are also kept uncompressed - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec7", actorSystem, master, serializer, 2000) store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") store.stop() diff --git a/sbt/sbt b/sbt/sbt index a3055c13c1..8f426d18e8 100755 --- a/sbt/sbt +++ b/sbt/sbt @@ -5,4 +5,4 @@ if [ "$MESOS_HOME" != "" ]; then fi export SPARK_HOME=$(cd "$(dirname $0)/.."; pwd) export SPARK_TESTING=1 # To put test classes on classpath -java -Xmx1200M -XX:MaxPermSize=200m $EXTRA_ARGS -jar $SPARK_HOME/sbt/sbt-launch-*.jar "$@" +java -Xmx1200M -XX:MaxPermSize=250m $EXTRA_ARGS -jar $SPARK_HOME/sbt/sbt-launch-*.jar "$@" -- cgit v1.2.3 From 909850729ec59b788645575fdc03df7cc51fe42b Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 27 Jan 2013 23:17:20 -0800 Subject: Rename more things from slave to executor --- .../scala/spark/deploy/worker/ExecutorRunner.scala | 2 +- .../spark/executor/StandaloneExecutorBackend.scala | 12 +++--- .../spark/scheduler/cluster/SlaveResources.scala | 4 -- .../cluster/SparkDeploySchedulerBackend.scala | 16 ++------ .../cluster/StandaloneClusterMessage.scala | 16 ++++---- .../cluster/StandaloneSchedulerBackend.scala | 48 +++++++++++----------- .../main/scala/spark/storage/BlockManagerUI.scala | 2 + .../main/scala/spark/util/MetadataCleaner.scala | 10 ++--- 8 files changed, 50 insertions(+), 60 deletions(-) delete mode 100644 core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index af3acfecb6..f5ff267d44 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -65,7 +65,7 @@ private[spark] class ExecutorRunner( } } - /** Replace variables such as {{SLAVEID}} and {{CORES}} in a command argument passed to us */ + /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ def substituteVariables(argument: String): String = argument match { case "{{EXECUTOR_ID}}" => execId.toString case "{{HOSTNAME}}" => hostname diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 435ee5743e..50871802ea 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -8,10 +8,10 @@ import akka.actor.{ActorRef, Actor, Props} import java.util.concurrent.{TimeUnit, ThreadPoolExecutor, SynchronousQueue} import akka.remote.RemoteClientLifeCycleEvent import spark.scheduler.cluster._ -import spark.scheduler.cluster.RegisteredSlave +import spark.scheduler.cluster.RegisteredExecutor import spark.scheduler.cluster.LaunchTask -import spark.scheduler.cluster.RegisterSlaveFailed -import spark.scheduler.cluster.RegisterSlave +import spark.scheduler.cluster.RegisterExecutorFailed +import spark.scheduler.cluster.RegisterExecutor private[spark] class StandaloneExecutorBackend( @@ -30,7 +30,7 @@ private[spark] class StandaloneExecutorBackend( try { logInfo("Connecting to master: " + masterUrl) master = context.actorFor(masterUrl) - master ! 
RegisterSlave(executorId, hostname, cores) + master ! RegisterExecutor(executorId, hostname, cores) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(master) // Doesn't work with remote actors, but useful for testing } catch { @@ -41,11 +41,11 @@ private[spark] class StandaloneExecutorBackend( } override def receive = { - case RegisteredSlave(sparkProperties) => + case RegisteredExecutor(sparkProperties) => logInfo("Successfully registered with master") executor.initialize(executorId, hostname, sparkProperties) - case RegisterSlaveFailed(message) => + case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) System.exit(1) diff --git a/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala b/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala deleted file mode 100644 index 96ebaa4601..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala +++ /dev/null @@ -1,4 +0,0 @@ -package spark.scheduler.cluster - -private[spark] -class SlaveResources(val slaveId: String, val hostname: String, val coresFree: Int) {} diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index f0792c1b76..6dd3ae003d 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,7 +19,6 @@ private[spark] class SparkDeploySchedulerBackend( var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt - val executorIdToSlaveId = new HashMap[String, String] // Memory used by each executor (in megabytes) val executorMemory = { @@ -47,7 +46,7 @@ private[spark] class SparkDeploySchedulerBackend( } override def stop() { - stopping = true; + stopping = true super.stop() client.stop() if (shutdownCallback != null) { @@ -67,23 +66,16 @@ private[spark] class SparkDeploySchedulerBackend( } def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) { - executorIdToSlaveId += id -> workerId logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format( id, host, cores, Utils.memoryMegabytesToString(memory))) } - def executorRemoved(id: String, message: String, exitStatus: Option[Int]) { + def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) { val reason: ExecutorLossReason = exitStatus match { case Some(code) => ExecutorExited(code) case None => SlaveLost(message) } - logInfo("Executor %s removed: %s".format(id, message)) - executorIdToSlaveId.get(id) match { - case Some(slaveId) => - executorIdToSlaveId.remove(id) - scheduler.executorLost(slaveId, reason) - case None => - logInfo("No slave ID known for executor %s".format(id)) - } + logInfo("Executor %s removed: %s".format(executorId, message)) + scheduler.executorLost(executorId, reason) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala index 1386cd9d44..c68f15bdfa 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -11,24 +11,26 @@ private[spark] case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage private[spark] 
-case class RegisteredSlave(sparkProperties: Seq[(String, String)]) extends StandaloneClusterMessage +case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) + extends StandaloneClusterMessage private[spark] -case class RegisterSlaveFailed(message: String) extends StandaloneClusterMessage +case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage -// Slaves to master +// Executors to master private[spark] -case class RegisterSlave(slaveId: String, host: String, cores: Int) extends StandaloneClusterMessage +case class RegisterExecutor(executorId: String, host: String, cores: Int) + extends StandaloneClusterMessage private[spark] -case class StatusUpdate(slaveId: String, taskId: Long, state: TaskState, data: SerializableBuffer) +case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer) extends StandaloneClusterMessage private[spark] object StatusUpdate { /** Alternate factory method that takes a ByteBuffer directly for the data field */ - def apply(slaveId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = { - StatusUpdate(slaveId, taskId, state, new SerializableBuffer(data)) + def apply(executorId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = { + StatusUpdate(executorId, taskId, state, new SerializableBuffer(data)) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 32be1e7a26..69822f568c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -24,9 +24,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor var totalCoreCount = new AtomicInteger(0) class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor { - val slaveActor = new HashMap[String, ActorRef] - val slaveAddress = new HashMap[String, Address] - val slaveHost = new HashMap[String, String] + val executorActor = new HashMap[String, ActorRef] + val executorAddress = new HashMap[String, Address] + val executorHost = new HashMap[String, String] val freeCores = new HashMap[String, Int] val actorToExecutorId = new HashMap[ActorRef, String] val addressToExecutorId = new HashMap[Address, String] @@ -37,17 +37,17 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } def receive = { - case RegisterSlave(executorId, host, cores) => - if (slaveActor.contains(executorId)) { - sender ! RegisterSlaveFailed("Duplicate executor ID: " + executorId) + case RegisterExecutor(executorId, host, cores) => + if (executorActor.contains(executorId)) { + sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) } else { logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! RegisteredSlave(sparkProperties) + sender ! 
RegisteredExecutor(sparkProperties) context.watch(sender) - slaveActor(executorId) = sender - slaveHost(executorId) = host + executorActor(executorId) = sender + executorHost(executorId) = host freeCores(executorId) = cores - slaveAddress(executorId) = sender.path.address + executorAddress(executorId) = sender.path.address actorToExecutorId(sender) = executorId addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) @@ -69,45 +69,45 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor context.stop(self) case Terminated(actor) => - actorToExecutorId.get(actor).foreach(removeSlave(_, "Akka actor terminated")) + actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated")) case RemoteClientDisconnected(transport, address) => - addressToExecutorId.get(address).foreach(removeSlave(_, "remote Akka client disconnected")) + addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected")) case RemoteClientShutdown(transport, address) => - addressToExecutorId.get(address).foreach(removeSlave(_, "remote Akka client shutdown")) + addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown")) } - // Make fake resource offers on all slaves + // Make fake resource offers on all executors def makeOffers() { launchTasks(scheduler.resourceOffers( - slaveHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) + executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) } - // Make fake resource offers on just one slave + // Make fake resource offers on just one executor def makeOffers(executorId: String) { launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, slaveHost(executorId), freeCores(executorId))))) + Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId))))) } // Launch tasks returned by a set of resource offers def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { freeCores(task.executorId) -= 1 - slaveActor(task.executorId) ! LaunchTask(task) + executorActor(task.executorId) ! LaunchTask(task) } } // Remove a disconnected slave from the cluster - def removeSlave(executorId: String, reason: String) { + def removeExecutor(executorId: String, reason: String) { logInfo("Slave " + executorId + " disconnected, so removing it") val numCores = freeCores(executorId) - actorToExecutorId -= slaveActor(executorId) - addressToExecutorId -= slaveAddress(executorId) - slaveActor -= executorId - slaveHost -= executorId + actorToExecutorId -= executorActor(executorId) + addressToExecutorId -= executorAddress(executorId) + executorActor -= executorId + executorHost -= executorId freeCores -= executorId - slaveHost -= executorId + executorHost -= executorId totalCoreCount.addAndGet(-numCores) scheduler.executorLost(executorId, SlaveLost(reason)) } diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index b7423c7234..956ede201e 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -21,6 +21,8 @@ object BlockManagerUI extends Logging { def start(actorSystem : ActorSystem, masterActor: ActorRef, sc: SparkContext) { val webUIDirectives = new BlockManagerUIDirectives(actorSystem, masterActor, sc) try { + // TODO: This needs to find a random free port to bind to. 
Unfortunately, there's no way + // in spray to do that, so we'll have to rely on something like new ServerSocket() val boundPort = AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt, webUIDirectives.handler, "BlockManagerHTTPServer") diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index 139e21d09e..721c4c6029 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -14,18 +14,16 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging val task = new TimerTask { def run() { try { - if (delaySeconds > 0) { - cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) - logInfo("Ran metadata cleaner for " + name) - } + cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) + logInfo("Ran metadata cleaner for " + name) } catch { case e: Exception => logError("Error running cleanup task for " + name, e) } } } - if (periodSeconds > 0) { - logInfo( + if (delaySeconds > 0) { + logDebug( "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and " + "period of " + periodSeconds + " secs") timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) -- cgit v1.2.3 From f03d9760fd8ac67fd0865cb355ba75d2eff507fe Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 27 Jan 2013 23:56:14 -0800 Subject: Clean up BlockManagerUI a little (make it not be an object, merge with Directives, and bind to a random port) --- core/src/main/scala/spark/SparkContext.scala | 7 +- core/src/main/scala/spark/Utils.scala | 17 ++- .../scala/spark/deploy/master/MasterWebUI.scala | 6 +- .../scala/spark/deploy/worker/WorkerWebUI.scala | 6 +- .../main/scala/spark/storage/BlockManagerUI.scala | 120 ++++++++++----------- core/src/main/scala/spark/util/AkkaUtils.scala | 6 +- .../main/scala/spark/util/MetadataCleaner.scala | 3 + 7 files changed, 91 insertions(+), 74 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 39721b47ae..77036c1275 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -44,6 +44,7 @@ import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} +import storage.BlockManagerUI import util.{MetadataCleaner, TimeStampedHashMap} /** @@ -88,8 +89,9 @@ class SparkContext( SparkEnv.set(env) // Start the BlockManager UI - spark.storage.BlockManagerUI.start(SparkEnv.get.actorSystem, - SparkEnv.get.blockManager.master.masterActor, this) + private[spark] val ui = new BlockManagerUI( + env.actorSystem, env.blockManager.master.masterActor, this) + ui.start() // Used to store a URL for each static file/jar together with the file's local timestamp private[spark] val addedFiles = HashMap[String, Long]() @@ -97,7 +99,6 @@ class SparkContext( // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]() - private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index ae77264372..1e58d01273 
100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -1,7 +1,7 @@ package spark import java.io._ -import java.net.{NetworkInterface, InetAddress, Inet4Address, URL, URI} +import java.net._ import java.util.{Locale, Random, UUID} import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} import org.apache.hadoop.conf.Configuration @@ -11,6 +11,7 @@ import scala.collection.JavaConversions._ import scala.io.Source import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder +import scala.Some /** * Various utility methods used by Spark. @@ -431,4 +432,18 @@ private object Utils extends Logging { } "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine) } + + /** + * Try to find a free port to bind to on the local host. This should ideally never be needed, + * except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray) + * don't let users bind to port 0 and then figure out which free port they actually bound to. + * We work around this by binding a ServerSocket and immediately unbinding it. This is *not* + * necessarily guaranteed to work, but it's the best we can do. + */ + def findFreePort(): Int = { + val socket = new ServerSocket(0) + val portBound = socket.getLocalPort + socket.close() + portBound + } } diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index 458ee2d665..a01774f511 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -14,12 +14,15 @@ import cc.spray.typeconversion.SprayJsonSupport._ import spark.deploy._ import spark.deploy.JsonProtocol._ +/** + * Web UI server for the standalone master. + */ private[spark] class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives { val RESOURCE_DIR = "spark/deploy/master/webui" val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(1 seconds) + implicit val timeout = Timeout(10 seconds) val handler = { get { @@ -76,5 +79,4 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct getFromResourceDirectory(RESOURCE_DIR) } } - } diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index f9489d99fc..ef81f072a3 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -13,12 +13,15 @@ import cc.spray.typeconversion.SprayJsonSupport._ import spark.deploy.{WorkerState, RequestWorkerState} import spark.deploy.JsonProtocol._ +/** + * Web UI server for the standalone worker. 
+ */ private[spark] class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives { val RESOURCE_DIR = "spark/deploy/worker/webui" val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(1 seconds) + implicit val timeout = Timeout(10 seconds) val handler = { get { @@ -50,5 +53,4 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct getFromResourceDirectory(RESOURCE_DIR) } } - } diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 956ede201e..eda320fa47 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -1,32 +1,41 @@ package spark.storage import akka.actor.{ActorRef, ActorSystem} -import akka.dispatch.Await import akka.pattern.ask import akka.util.Timeout import akka.util.duration._ -import cc.spray.Directives import cc.spray.directives._ import cc.spray.typeconversion.TwirlSupport._ +import cc.spray.Directives import scala.collection.mutable.ArrayBuffer -import spark.{Logging, SparkContext, SparkEnv} +import spark.{Logging, SparkContext} import spark.util.AkkaUtils import spark.Utils +/** + * Web UI server for the BlockManager inside each SparkContext. + */ private[spark] -object BlockManagerUI extends Logging { +class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, sc: SparkContext) + extends Directives with Logging { + + val STATIC_RESOURCE_DIR = "spark/deploy/static" + + implicit val timeout = Timeout(10 seconds) - /* Starts the Web interface for the BlockManager */ - def start(actorSystem : ActorSystem, masterActor: ActorRef, sc: SparkContext) { - val webUIDirectives = new BlockManagerUIDirectives(actorSystem, masterActor, sc) + /** Start a HTTP server to run the Web interface */ + def start() { try { - // TODO: This needs to find a random free port to bind to. Unfortunately, there's no way - // in spray to do that, so we'll have to rely on something like new ServerSocket() - val boundPort = AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", - Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt, - webUIDirectives.handler, "BlockManagerHTTPServer") - logInfo("Started BlockManager web UI at %s:%d".format(Utils.localHostName(), boundPort)) + val port = if (System.getProperty("spark.ui.port") != null) { + System.getProperty("spark.ui.port").toInt + } else { + // TODO: Unfortunately, it's not possible to pass port 0 to spray and figure out which + // random port it bound to, so we have to try to find a local one by creating a socket. + Utils.findFreePort() + } + AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, handler, "BlockManagerHTTPServer") + logInfo("Started BlockManager web UI at http://%s:%d".format(Utils.localHostName(), port)) } catch { case e: Exception => logError("Failed to create BlockManager WebUI", e) @@ -34,58 +43,43 @@ object BlockManagerUI extends Logging { } } -} - - -private[spark] -class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, - sc: SparkContext) extends Directives { - - val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(1 seconds) - val handler = { - - get { path("") { completeWith { - // Request the current storage status from the Master - val future = master ? 
GetStorageStatus - future.map { status => - val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray - - // Calculate macro-level statistics - val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) - val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) - val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)) - .reduceOption(_+_).getOrElse(0L) - - val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) - - spark.storage.html.index. - render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) - } - }}} ~ - get { path("rdd") { parameter("id") { id => { completeWith { - val future = master ? GetStorageStatus - future.map { status => - val prefix = "rdd_" + id.toString - - - val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray - val filteredStorageStatusList = StorageUtils. - filterStorageStatusByPrefix(storageStatusList, prefix) - - val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head - - spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) - + get { + path("") { + completeWith { + // Request the current storage status from the Master + val future = blockManagerMaster ? GetStorageStatus + future.map { status => + // Calculate macro-level statistics + val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray + val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) + val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) + val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)) + .reduceOption(_+_).getOrElse(0L) + val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) + spark.storage.html.index. + render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) + } + } + } ~ + path("rdd") { + parameter("id") { id => + completeWith { + val future = blockManagerMaster ? GetStorageStatus + future.map { status => + val prefix = "rdd_" + id.toString + val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray + val filteredStorageStatusList = StorageUtils. + filterStorageStatusByPrefix(storageStatusList, prefix) + val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head + spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) + } + } + } + } ~ + pathPrefix("static") { + getFromResourceDirectory(STATIC_RESOURCE_DIR) } - }}}}} ~ - pathPrefix("static") { - getFromResourceDirectory(STATIC_RESOURCE_DIR) } - } - - - } diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index 775ff8f1aa..e0fdeffbc4 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -1,6 +1,6 @@ package spark.util -import akka.actor.{Props, ActorSystemImpl, ActorSystem} +import akka.actor.{ActorRef, Props, ActorSystemImpl, ActorSystem} import com.typesafe.config.ConfigFactory import akka.util.duration._ import akka.pattern.ask @@ -55,7 +55,7 @@ private[spark] object AkkaUtils { * handle requests. Returns the bound port or throws a SparkException on failure. 
*/ def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, - name: String = "HttpServer"): Int = { + name: String = "HttpServer"): ActorRef = { val ioWorker = new IoWorker(actorSystem).start() val httpService = actorSystem.actorOf(Props(new HttpService(route))) val rootService = actorSystem.actorOf(Props(new SprayCanRootService(httpService))) @@ -67,7 +67,7 @@ private[spark] object AkkaUtils { try { Await.result(future, timeout) match { case bound: HttpServer.Bound => - return bound.endpoint.getPort + return server case other: Any => throw new SparkException("Failed to bind web UI to port " + port + ": " + other) } diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index 721c4c6029..51fb440108 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -5,6 +5,9 @@ import java.util.{TimerTask, Timer} import spark.Logging +/** + * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) + */ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { val delaySeconds = MetadataCleaner.getDelaySeconds -- cgit v1.2.3 From 286f8f876ff495df33a7966e77ca90d69f338450 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 28 Jan 2013 01:29:27 -0800 Subject: Change time unit in MetadataCleaner to seconds --- core/src/main/scala/spark/util/MetadataCleaner.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index 51fb440108..6cf93a9b17 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -9,7 +9,6 @@ import spark.Logging * Runs a timer task to periodically clean up metadata (e.g. 
old files or hashtable entries) */ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { - val delaySeconds = MetadataCleaner.getDelaySeconds val periodSeconds = math.max(10, delaySeconds / 10) val timer = new Timer(name + " cleanup timer", true) @@ -39,7 +38,7 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging object MetadataCleaner { - def getDelaySeconds = (System.getProperty("spark.cleaner.delay", "-100").toDouble * 60).toInt - def setDelaySeconds(delay: Long) { System.setProperty("spark.cleaner.delay", delay.toString) } + def getDelaySeconds = System.getProperty("spark.cleaner.delay", "-1").toInt + def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.delay", delay.toString) } } -- cgit v1.2.3 From 07f568e1bfc67eead88e2c5dbfb9cac23e1ac8bc Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 24 Jan 2013 15:27:29 -0800 Subject: SPARK-658: Adding logging of stage duration --- .../main/scala/spark/scheduler/DAGScheduler.scala | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index bd541d4207..8aad667182 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -86,6 +86,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] + val stageSubmissionTimes = new HashMap[Stage, Long] val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) @@ -393,6 +394,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with logDebug("New pending tasks: " + myPending) taskSched.submitTasks( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority)) + if (!stageSubmissionTimes.contains(stage)) { + stageSubmissionTimes.put(stage, System.currentTimeMillis()) + } } else { logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) @@ -407,6 +411,15 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def handleTaskCompletion(event: CompletionEvent) { val task = event.task val stage = idToStage(task.stageId) + + def stageFinished(stage: Stage) = { + val serviceTime = stageSubmissionTimes.remove(stage) match { + case Some(t) => (System.currentTimeMillis() - t).toString + case _ => "Unkown" + } + logInfo("%s (%s) finished in %s ms".format(stage, stage.origin, serviceTime)) + running -= stage + } event.reason match { case Success => logInfo("Completed " + task) @@ -421,13 +434,13 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with if (!job.finished(rt.outputId)) { job.finished(rt.outputId) = true job.numFinished += 1 - job.listener.taskSucceeded(rt.outputId, event.result) // If the whole job has finished, remove it if (job.numFinished == job.numPartitions) { activeJobs -= job resultStageToJob -= stage - running -= stage + stageFinished(stage) } + job.listener.taskSucceeded(rt.outputId, event.result) } case None => logInfo("Ignoring result from " + rt + " because its job has finished") @@ -444,8 +457,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with stage.addOutputLoc(smt.partition, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { 
- logInfo(stage + " (" + stage.origin + ") finished; looking for newly runnable stages") - running -= stage + stageFinished(stage) + logInfo("looking for newly runnable stages") logInfo("running: " + running) logInfo("waiting: " + waiting) logInfo("failed: " + failed) -- cgit v1.2.3 From c423be7d8e1349fc00431328b76b52f4eee8a975 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 24 Jan 2013 18:25:57 -0800 Subject: Renaming stage finished function --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 8aad667182..bce7418e87 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -412,7 +412,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val task = event.task val stage = idToStage(task.stageId) - def stageFinished(stage: Stage) = { + def markStageAsFinished(stage: Stage) = { val serviceTime = stageSubmissionTimes.remove(stage) match { case Some(t) => (System.currentTimeMillis() - t).toString case _ => "Unkown" @@ -438,7 +438,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with if (job.numFinished == job.numPartitions) { activeJobs -= job resultStageToJob -= stage - stageFinished(stage) + markStageAsFinished(stage) } job.listener.taskSucceeded(rt.outputId, event.result) } @@ -457,7 +457,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with stage.addOutputLoc(smt.partition, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { - stageFinished(stage) + markStageAsFinished(stage) logInfo("looking for newly runnable stages") logInfo("running: " + running) logInfo("waiting: " + waiting) -- cgit v1.2.3 From 501433f1d59b1b326c0a7169fa1fd6136f7628e3 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 28 Jan 2013 10:17:35 -0800 Subject: Making submission time a field --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 7 +++---- core/src/main/scala/spark/scheduler/Stage.scala | 3 +++ 2 files changed, 6 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index bce7418e87..7ba1f3430a 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -86,7 +86,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] - val stageSubmissionTimes = new HashMap[Stage, Long] val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) @@ -394,8 +393,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with logDebug("New pending tasks: " + myPending) taskSched.submitTasks( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority)) - if (!stageSubmissionTimes.contains(stage)) { - stageSubmissionTimes.put(stage, System.currentTimeMillis()) + if (!stage.submissionTime.isDefined) { + stage.submissionTime = Some(System.currentTimeMillis()) } } else { logDebug("Stage " + stage + " is actually done; %b %d %d".format( @@ -413,7 +412,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val stage = 
idToStage(task.stageId) def markStageAsFinished(stage: Stage) = { - val serviceTime = stageSubmissionTimes.remove(stage) match { + val serviceTime = stage.submissionTime match { case Some(t) => (System.currentTimeMillis() - t).toString case _ => "Unkown" } diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala index e9419728e3..374114d870 100644 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -32,6 +32,9 @@ private[spark] class Stage( val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) var numAvailableOutputs = 0 + /** When first task was submitted to scheduler. */ + var submissionTime: Option[Long] = None + private var nextAttemptId = 0 def isAvailable: Boolean = { -- cgit v1.2.3 From efff7bfb3382f4e07f9fad0e6e647c0ec629355e Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 28 Jan 2013 20:23:11 -0800 Subject: add long and float accumulatorparams --- core/src/main/scala/spark/SparkContext.scala | 10 ++++++++++ core/src/test/scala/spark/AccumulatorSuite.scala | 6 ++++++ 2 files changed, 16 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 77036c1275..dc9b8688b3 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -673,6 +673,16 @@ object SparkContext { def zero(initialValue: Int) = 0 } + implicit object LongAccumulatorParam extends AccumulatorParam[Long] { + def addInPlace(t1: Long, t2: Long) = t1 + t2 + def zero(initialValue: Long) = 0l + } + + implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { + def addInPlace(t1: Float, t2: Float) = t1 + t2 + def zero(initialValue: Float) = 0f + } + // TODO: Add AccumulatorParams for other types, e.g. lists and strings implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) = diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala index 78d64a44ae..ac8ae7d308 100644 --- a/core/src/test/scala/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/spark/AccumulatorSuite.scala @@ -17,6 +17,12 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte val d = sc.parallelize(1 to 20) d.foreach{x => acc += x} acc.value should be (210) + + + val longAcc = sc.accumulator(0l) + val maxInt = Integer.MAX_VALUE.toLong + d.foreach{x => longAcc += maxInt + x} + longAcc.value should be (210l + maxInt * 20) } test ("value not assignable from tasks") { -- cgit v1.2.3 From 1f9b486a8be49ef547ac1532cafd63c4c9d4ddda Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 28 Jan 2013 20:24:54 -0800 Subject: Some DEBUG-level log cleanup. A few changes to make the DEBUG-level logs less noisy and more readable. 
- Moved a few very frequent messages to Trace - Changed some BlockManger log messages to make them more understandable SPARK-666 #resolve --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 8 ++++---- core/src/main/scala/spark/storage/BlockManager.scala | 14 +++++++------- .../main/scala/spark/storage/BlockManagerMasterActor.scala | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index bd541d4207..f10d7cc84e 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -308,10 +308,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } else { // TODO: We might want to run this less often, when we are sure that something has become // runnable that wasn't before. - logDebug("Checking for newly runnable parent stages") - logDebug("running: " + running) - logDebug("waiting: " + waiting) - logDebug("failed: " + failed) + logTrace("Checking for newly runnable parent stages") + logTrace("running: " + running) + logTrace("waiting: " + waiting) + logTrace("failed: " + failed) val waiting2 = waiting.toArray waiting.clear() for (stage <- waiting2.sortBy(_.priority)) { diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 1215d5f5c8..c61fd75c2b 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -243,7 +243,7 @@ class BlockManager( val startTimeMs = System.currentTimeMillis var managers = master.getLocations(blockId) val locations = managers.map(_.ip) - logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs)) + logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -253,7 +253,7 @@ class BlockManager( def getLocations(blockIds: Array[String]): Array[Seq[String]] = { val startTimeMs = System.currentTimeMillis val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray - logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) + logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -645,7 +645,7 @@ class BlockManager( var size = 0L myInfo.synchronized { - logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") if (level.useMemory) { @@ -677,8 +677,10 @@ class BlockManager( } logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) + // Replicate block if required if (level.replication > 1) { + val remoteStartTime = System.currentTimeMillis // Serialize the block if not already done if (bytesAfterPut == null) { if (valuesAfterPut == null) { @@ -688,12 +690,10 @@ class BlockManager( bytesAfterPut = dataSerialize(blockId, valuesAfterPut) } replicate(blockId, bytesAfterPut, level) + logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime)) } - BlockManager.dispose(bytesAfterPut) - logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)) - return size } @@ -978,7 +978,7 @@ object BlockManager extends Logging { */ def dispose(buffer: ByteBuffer) { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - logDebug("Unmapping " + 
buffer) + logTrace("Unmapping " + buffer) if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { buffer.asInstanceOf[DirectBuffer].cleaner().clean() } diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index f88517f1a3..2830bc6297 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -115,7 +115,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } def expireDeadHosts() { - logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.") + logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.") val now = System.currentTimeMillis() val minSeenTime = now - slaveTimeout val toRemove = new HashSet[BlockManagerId] -- cgit v1.2.3 From 7ee824e42ebaa1fc0b0248e0a35021108625ed14 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 28 Jan 2013 21:48:32 -0800 Subject: Units from ms -> s --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 7ba1f3430a..b8336d9d06 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -413,10 +413,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def markStageAsFinished(stage: Stage) = { val serviceTime = stage.submissionTime match { - case Some(t) => (System.currentTimeMillis() - t).toString + case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0) case _ => "Unkown" } - logInfo("%s (%s) finished in %s ms".format(stage, stage.origin, serviceTime)) + logInfo("%s (%s) finished in %s s".format(stage, stage.origin, serviceTime)) running -= stage } event.reason match { -- cgit v1.2.3 From b45857c965219e2d26f35adb2ea3a2b831fdb77f Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 28 Jan 2013 23:56:56 -0600 Subject: Add RDD.toDebugString. Original idea by Nathan Kronenfeld. --- core/src/main/scala/spark/RDD.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 0d3857f9dd..172431c31a 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -638,4 +638,14 @@ abstract class RDD[T: ClassManifest]( protected[spark] def clearDependencies() { dependencies_ = null } + + /** A description of this RDD and its recursive dependencies for debugging. */ + def toDebugString(): String = { + def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = { + Seq(prefix + rdd) ++ rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " ")) + } + debugString(this).mkString("\n") + } + + override def toString() = "%s[%d] at %s".format(getClass.getSimpleName, id, origin) } -- cgit v1.2.3 From 951cfd9ba2a9239a777f156f10af820e9df49606 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 29 Jan 2013 00:02:17 -0600 Subject: Add JavaRDDLike.toDebugString(). 
--- core/src/main/scala/spark/api/java/JavaRDDLike.scala | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 4c95c989b5..44f778e5c2 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -330,4 +330,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround case _ => Optional.absent() } } + + /** A description of this RDD and its recursive dependencies for debugging. */ + def toDebugString(): String = { + rdd.toDebugString() + } } -- cgit v1.2.3 From 3cda14af3fea97c2372c7335505e9dad7e0dd117 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 29 Jan 2013 00:12:31 -0600 Subject: Add number of splits. --- core/src/main/scala/spark/RDD.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 172431c31a..39bacd2afb 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -642,7 +642,8 @@ abstract class RDD[T: ClassManifest]( /** A description of this RDD and its recursive dependencies for debugging. */ def toDebugString(): String = { def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = { - Seq(prefix + rdd) ++ rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " ")) + Seq(prefix + rdd + " (" + rdd.splits.size + " splits)") ++ + rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " ")) } debugString(this).mkString("\n") } -- cgit v1.2.3 From cbf72bffa5874319c7ee7117a073e9d01fa51585 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 29 Jan 2013 00:20:36 -0600 Subject: Include name, if set, in RDD.toString(). --- core/src/main/scala/spark/RDD.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 39bacd2afb..a23441483e 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -648,5 +648,10 @@ abstract class RDD[T: ClassManifest]( debugString(this).mkString("\n") } - override def toString() = "%s[%d] at %s".format(getClass.getSimpleName, id, origin) + override def toString(): String = "%s%s[%d] at %s".format( + Option(name).map(_ + " ").getOrElse(""), + getClass.getSimpleName, + id, + origin) + } -- cgit v1.2.3 From 64ba6a8c2c5f46e6de6deb6a6fd576a55cb3b198 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 28 Jan 2013 22:30:12 -0800 Subject: Simplify checkpointing code and RDD class a little: - RDD's getDependencies and getSplits methods are now guaranteed to be called only once, so subclasses can safely do computation in there without worrying about caching the results. - The management of a "splits_" variable that is cleared out when we checkpoint an RDD is now done in the RDD class. - A few of the RDD subclasses are simpler. - CheckpointRDD's compute() method no longer assumes that it is given a CheckpointRDDSplit -- it can work just as well on a split from the original RDD, because it only looks at its index. This is important because things like UnionRDD and ZippedRDD remember the parent's splits as part of their own and wouldn't work on checkpointed parents. - RDD.iterator can now reuse cached data if an RDD is computed before it is checkpointed. 
It seems like it wouldn't do this before (it always called iterator() on the CheckpointRDD, which read from HDFS). --- core/src/main/scala/spark/CacheManager.scala | 6 +- core/src/main/scala/spark/PairRDDFunctions.scala | 4 +- core/src/main/scala/spark/RDD.scala | 130 ++++++++++++--------- core/src/main/scala/spark/RDDCheckpointData.scala | 19 +-- .../main/scala/spark/api/java/JavaRDDLike.scala | 2 +- core/src/main/scala/spark/rdd/CartesianRDD.scala | 12 +- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 61 +++++----- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 14 +-- core/src/main/scala/spark/rdd/MappedRDD.scala | 6 +- .../main/scala/spark/rdd/PartitionPruningRDD.scala | 13 +-- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 8 +- core/src/main/scala/spark/rdd/UnionRDD.scala | 14 +-- core/src/main/scala/spark/rdd/ZippedRDD.scala | 7 +- .../main/scala/spark/util/MetadataCleaner.scala | 4 +- core/src/test/scala/spark/CheckpointSuite.scala | 21 ++-- 15 files changed, 153 insertions(+), 168 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala index a0b53fd9d6..711435c333 100644 --- a/core/src/main/scala/spark/CacheManager.scala +++ b/core/src/main/scala/spark/CacheManager.scala @@ -10,9 +10,9 @@ import spark.storage.{BlockManager, StorageLevel} private[spark] class CacheManager(blockManager: BlockManager) extends Logging { private val loading = new HashSet[String] - /** Gets or computes an RDD split. Used by RDD.iterator() when a RDD is cached. */ + /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */ def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) - : Iterator[T] = { + : Iterator[T] = { val key = "rdd_%d_%d".format(rdd.id, split.index) logInfo("Cache key is " + key) blockManager.get(key) match { @@ -50,7 +50,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { // If we got here, we have to load the split val elements = new ArrayBuffer[Any] logInfo("Computing partition " + split) - elements ++= rdd.compute(split, context) + elements ++= rdd.computeOrReadCheckpoint(split, context) // Try to put this block in the blockManager blockManager.put(key, elements, storageLevel, true) return elements.iterator.asInstanceOf[Iterator[T]] diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 53b051f1c5..231e23a7de 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -649,9 +649,7 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( } private[spark] -class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) - extends RDD[(K, U)](prev) { - +class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev) { override def getSplits = firstParent[(K, V)].splits override val partitioner = firstParent[(K, V)].partitioner override def compute(split: Split, context: TaskContext) = diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 0d3857f9dd..dbad6d4c83 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -1,27 +1,17 @@ package spark -import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream} import java.net.URL import java.util.{Date, Random} import java.util.{HashMap => JHashMap} -import 
java.util.concurrent.atomic.AtomicLong import scala.collection.Map import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import org.apache.hadoop.fs.Path import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapred.FileOutputCommitter -import org.apache.hadoop.mapred.HadoopWriter -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.OutputCommitter -import org.apache.hadoop.mapred.OutputFormat -import org.apache.hadoop.mapred.SequenceFileOutputFormat import org.apache.hadoop.mapred.TextOutputFormat import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} @@ -30,7 +20,6 @@ import spark.partial.BoundedDouble import spark.partial.CountEvaluator import spark.partial.GroupedCountEvaluator import spark.partial.PartialResult -import spark.rdd.BlockRDD import spark.rdd.CartesianRDD import spark.rdd.FilteredRDD import spark.rdd.FlatMappedRDD @@ -73,11 +62,11 @@ import SparkContext._ * on RDD internals. */ abstract class RDD[T: ClassManifest]( - @transient var sc: SparkContext, - var dependencies_ : List[Dependency[_]] + @transient private var sc: SparkContext, + @transient private var deps: Seq[Dependency[_]] ) extends Serializable with Logging { - + /** Construct an RDD with just a one-to-one dependency on one parent */ def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) @@ -85,25 +74,27 @@ abstract class RDD[T: ClassManifest]( // Methods that should be implemented by subclasses of RDD // ======================================================================= - /** Function for computing a given partition. */ + /** Implemented by subclasses to compute a given partition. */ def compute(split: Split, context: TaskContext): Iterator[T] - /** Set of partitions in this RDD. */ - protected def getSplits(): Array[Split] + /** + * Implemented by subclasses to return the set of partitions in this RDD. This method will only + * be called once, so it is safe to implement a time-consuming computation in it. + */ + protected def getSplits: Array[Split] - /** How this RDD depends on any parent RDDs. */ - protected def getDependencies(): List[Dependency[_]] = dependencies_ + /** + * Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only + * be called once, so it is safe to implement a time-consuming computation in it. + */ + protected def getDependencies: Seq[Dependency[_]] = deps - /** A friendly name for this RDD */ - var name: String = null - /** Optionally overridden by subclasses to specify placement preferences. */ protected def getPreferredLocations(split: Split): Seq[String] = Nil /** Optionally overridden by subclasses to specify how they are partitioned. */ val partitioner: Option[Partitioner] = None - // ======================================================================= // Methods and fields available on all RDDs // ======================================================================= @@ -111,13 +102,16 @@ abstract class RDD[T: ClassManifest]( /** A unique ID for this RDD (within its SparkContext). 
*/ val id = sc.newRddId() + /** A friendly name for this RDD */ + var name: String = null + /** Assign a name to this RDD */ def setName(_name: String) = { name = _name this } - /** + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. */ @@ -142,15 +136,24 @@ abstract class RDD[T: ClassManifest]( /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel + // Our dependencies and splits will be gotten by calling subclass's methods below, and will + // be overwritten when we're checkpointed + private var dependencies_ : Seq[Dependency[_]] = null + @transient private var splits_ : Array[Split] = null + + /** An Option holding our checkpoint RDD, if we are checkpointed */ + private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD) + /** - * Get the preferred location of a split, taking into account whether the + * Get the list of dependencies of this RDD, taking into account whether the * RDD is checkpointed or not. */ - final def preferredLocations(split: Split): Seq[String] = { - if (isCheckpointed) { - checkpointData.get.getPreferredLocations(split) - } else { - getPreferredLocations(split) + final def dependencies: Seq[Dependency[_]] = { + checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse { + if (dependencies_ == null) { + dependencies_ = getDependencies + } + dependencies_ } } @@ -159,22 +162,21 @@ abstract class RDD[T: ClassManifest]( * RDD is checkpointed or not. */ final def splits: Array[Split] = { - if (isCheckpointed) { - checkpointData.get.getSplits - } else { - getSplits + checkpointRDD.map(_.splits).getOrElse { + if (splits_ == null) { + splits_ = getSplits + } + splits_ } } /** - * Get the list of dependencies of this RDD, taking into account whether the + * Get the preferred location of a split, taking into account whether the * RDD is checkpointed or not. */ - final def dependencies: List[Dependency[_]] = { - if (isCheckpointed) { - dependencies_ - } else { - getDependencies + final def preferredLocations(split: Split): Seq[String] = { + checkpointRDD.map(_.getPreferredLocations(split)).getOrElse { + getPreferredLocations(split) } } @@ -184,10 +186,19 @@ abstract class RDD[T: ClassManifest]( * subclasses of RDD. */ final def iterator(split: Split, context: TaskContext): Iterator[T] = { - if (isCheckpointed) { - checkpointData.get.iterator(split, context) - } else if (storageLevel != StorageLevel.NONE) { + if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) + } else { + computeOrReadCheckpoint(split, context) + } + } + + /** + * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing. 
+ */ + private[spark] def computeOrReadCheckpoint(split: Split, context: TaskContext): Iterator[T] = { + if (isCheckpointed) { + firstParent[T].iterator(split, context) } else { compute(split, context) } @@ -578,15 +589,15 @@ abstract class RDD[T: ClassManifest]( /** * Return whether this RDD has been checkpointed or not */ - def isCheckpointed(): Boolean = { - if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false + def isCheckpointed: Boolean = { + checkpointData.map(_.isCheckpointed).getOrElse(false) } /** * Gets the name of the file to which this RDD was checkpointed */ - def getCheckpointFile(): Option[String] = { - if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None + def getCheckpointFile: Option[String] = { + checkpointData.flatMap(_.getCheckpointFile) } // ======================================================================= @@ -611,31 +622,36 @@ abstract class RDD[T: ClassManifest]( def context = sc /** - * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler + * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler * after a job using this RDD has completed (therefore the RDD has been materialized and * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. */ - protected[spark] def doCheckpoint() { - if (checkpointData.isDefined) checkpointData.get.doCheckpoint() - dependencies.foreach(_.rdd.doCheckpoint()) + private[spark] def doCheckpoint() { + if (checkpointData.isDefined) { + checkpointData.get.doCheckpoint() + } else { + dependencies.foreach(_.rdd.doCheckpoint()) + } } /** - * Changes the dependencies of this RDD from its original parents to the new RDD - * (`newRDD`) created from the checkpoint file. + * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`) + * created from the checkpoint file, and forget its old dependencies and splits. */ - protected[spark] def changeDependencies(newRDD: RDD[_]) { + private[spark] def markCheckpointed(checkpointRDD: RDD[_]) { clearDependencies() - dependencies_ = List(new OneToOneDependency(newRDD)) + dependencies_ = null + splits_ = null + deps = null // Forget the constructor argument for dependencies too } /** * Clears the dependencies of this RDD. This method must ensure that all references * to the original parent RDDs is removed to enable the parent RDDs to be garbage * collected. Subclasses of RDD may override this method for implementing their own cleaning - * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + * logic. See [[spark.rdd.UnionRDD]] for an example. */ - protected[spark] def clearDependencies() { + protected def clearDependencies() { dependencies_ = null } } diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index 18df530b7d..a4a4ebaf53 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -20,7 +20,7 @@ private[spark] object CheckpointState extends Enumeration { * of the checkpointed RDD. */ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) -extends Logging with Serializable { + extends Logging with Serializable { import CheckpointState._ @@ -31,7 +31,7 @@ extends Logging with Serializable { @transient var cpFile: Option[String] = None // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. 
- @transient var cpRDD: Option[RDD[T]] = None + var cpRDD: Option[RDD[T]] = None // Mark the RDD for checkpointing def markForCheckpoint() { @@ -41,12 +41,12 @@ extends Logging with Serializable { } // Is the RDD already checkpointed - def isCheckpointed(): Boolean = { + def isCheckpointed: Boolean = { RDDCheckpointData.synchronized { cpState == Checkpointed } } // Get the file to which this RDD was checkpointed to as an Option - def getCheckpointFile(): Option[String] = { + def getCheckpointFile: Option[String] = { RDDCheckpointData.synchronized { cpFile } } @@ -71,7 +71,7 @@ extends Logging with Serializable { RDDCheckpointData.synchronized { cpFile = Some(path) cpRDD = Some(newRDD) - rdd.changeDependencies(newRDD) + rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and splits cpState = Checkpointed RDDCheckpointData.clearTaskCaches() logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) @@ -79,7 +79,7 @@ extends Logging with Serializable { } // Get preferred location of a split after checkpointing - def getPreferredLocations(split: Split) = { + def getPreferredLocations(split: Split): Seq[String] = { RDDCheckpointData.synchronized { cpRDD.get.preferredLocations(split) } @@ -91,9 +91,10 @@ extends Logging with Serializable { } } - // Get iterator. This is called at the worker nodes. - def iterator(split: Split, context: TaskContext): Iterator[T] = { - rdd.firstParent[T].iterator(split, context) + def checkpointRDD: Option[RDD[T]] = { + RDDCheckpointData.synchronized { + cpRDD + } } } diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 4c95c989b5..46fd8fe85e 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -319,7 +319,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround /** * Return whether this RDD has been checkpointed or not */ - def isCheckpointed(): Boolean = rdd.isCheckpointed() + def isCheckpointed: Boolean = rdd.isCheckpointed /** * Gets the name of the file to which this RDD was checkpointed diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 453d410ad4..0f9ca06531 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,7 +1,7 @@ package spark.rdd import java.io.{ObjectOutputStream, IOException} -import spark.{OneToOneDependency, NarrowDependency, RDD, SparkContext, Split, TaskContext} +import spark._ private[spark] @@ -35,7 +35,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( val numSplitsInRdd2 = rdd2.splits.size - @transient var splits_ = { + override def getSplits: Array[Split] = { // create the cross product split val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) for (s1 <- rdd1.splits; s2 <- rdd2.splits) { @@ -45,8 +45,6 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( array } - override def getSplits = splits_ - override def getPreferredLocations(split: Split) = { val currSplit = split.asInstanceOf[CartesianSplit] rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) @@ -58,7 +56,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) } - var deps_ = List( + override def getDependencies: Seq[Dependency[_]] = List( new NarrowDependency(rdd1) { def getParents(id: Int): Seq[Int] = List(id / 
numSplitsInRdd2) }, @@ -67,11 +65,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( } ) - override def getDependencies = deps_ - override def clearDependencies() { - deps_ = Nil - splits_ = null rdd1 = null rdd2 = null } diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 6f00f6ac73..96b593ba7c 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -9,23 +9,26 @@ import org.apache.hadoop.fs.Path import java.io.{File, IOException, EOFException} import java.text.NumberFormat -private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split { - override val index: Int = idx -} +private[spark] class CheckpointRDDSplit(val index: Int) extends Split {} /** * This RDD represents a RDD checkpoint file (similar to HadoopRDD). */ private[spark] -class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) +class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String) extends RDD[T](sc, Nil) { - @transient val path = new Path(checkpointPath) - @transient val fs = path.getFileSystem(new Configuration()) + @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) @transient val splits_ : Array[Split] = { - val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted - splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray + val dirContents = fs.listStatus(new Path(checkpointPath)) + val splitFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted + val numSplits = splitFiles.size + if (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || + !splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1))) { + throw new SparkException("Invalid checkpoint directory: " + checkpointPath) + } + Array.tabulate(numSplits)(i => new CheckpointRDDSplit(i)) } checkpointData = Some(new RDDCheckpointData[T](this)) @@ -34,36 +37,34 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) override def getSplits = splits_ override def getPreferredLocations(split: Split): Seq[String] = { - val status = fs.getFileStatus(path) + val status = fs.getFileStatus(new Path(checkpointPath)) val locations = fs.getFileBlockLocations(status, 0, status.getLen) - locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost") + locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost") } override def compute(split: Split, context: TaskContext): Iterator[T] = { - CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile, context) + val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)) + CheckpointRDD.readFromFile(file, context) } override def checkpoint() { - // Do nothing. Hadoop RDD should not be checkpointed. + // Do nothing. CheckpointRDD should not be checkpointed. 
} } private[spark] object CheckpointRDD extends Logging { - def splitIdToFileName(splitId: Int): String = { - val numfmt = NumberFormat.getInstance() - numfmt.setMinimumIntegerDigits(5) - numfmt.setGroupingUsed(false) - "part-" + numfmt.format(splitId) + def splitIdToFile(splitId: Int): String = { + "part-%05d".format(splitId) } - def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) { + def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { val outputDir = new Path(path) val fs = outputDir.getFileSystem(new Configuration()) - val finalOutputName = splitIdToFileName(context.splitId) + val finalOutputName = splitIdToFile(ctx.splitId) val finalOutputPath = new Path(outputDir, finalOutputName) - val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId) + val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId) if (fs.exists(tempOutputPath)) { throw new IOException("Checkpoint failed: temporary path " + @@ -83,22 +84,22 @@ private[spark] object CheckpointRDD extends Logging { serializeStream.close() if (!fs.rename(tempOutputPath, finalOutputPath)) { - if (!fs.delete(finalOutputPath, true)) { - throw new IOException("Checkpoint failed: failed to delete earlier output of task " - + context.attemptId) - } - if (!fs.rename(tempOutputPath, finalOutputPath)) { + if (!fs.exists(finalOutputPath)) { + fs.delete(tempOutputPath, false) throw new IOException("Checkpoint failed: failed to save output of task: " - + context.attemptId) + + ctx.attemptId + " and final output path does not exist") + } else { + // Some other copy of this task must've finished before us and renamed it + logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it") + fs.delete(tempOutputPath, false) } } } - def readFromFile[T](path: String, context: TaskContext): Iterator[T] = { - val inputPath = new Path(path) - val fs = inputPath.getFileSystem(new Configuration()) + def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = { + val fs = path.getFileSystem(new Configuration()) val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - val fileInputStream = fs.open(inputPath, bufferSize) + val fileInputStream = fs.open(path, bufferSize) val serializer = SparkEnv.get.serializer.newInstance() val deserializeStream = serializer.deserializeStream(fileInputStream) diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 167755bbba..4c57434b65 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -27,11 +27,11 @@ private[spark] case class CoalescedRDDSplit( * or to avoid having a large number of small tasks when processing a directory with many files. 
*/ class CoalescedRDD[T: ClassManifest]( - var prev: RDD[T], + @transient var prev: RDD[T], maxPartitions: Int) - extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs + extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies - @transient var splits_ : Array[Split] = { + override def getSplits: Array[Split] = { val prevSplits = prev.splits if (prevSplits.length < maxPartitions) { prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) } @@ -44,26 +44,20 @@ class CoalescedRDD[T: ClassManifest]( } } - override def getSplits = splits_ - override def compute(split: Split, context: TaskContext): Iterator[T] = { split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit => firstParent[T].iterator(parentSplit, context) } } - var deps_ : List[Dependency[_]] = List( + override def getDependencies: Seq[Dependency[_]] = List( new NarrowDependency(prev) { def getParents(id: Int): Seq[Int] = splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices } ) - override def getDependencies() = deps_ - override def clearDependencies() { - deps_ = Nil - splits_ = null prev = null } } diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index c6ceb272cd..5466c9c657 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -3,13 +3,11 @@ package spark.rdd import spark.{RDD, Split, TaskContext} private[spark] -class MappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: T => U) +class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U) extends RDD[U](prev) { override def getSplits = firstParent[T].splits override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context).map(f) -} \ No newline at end of file +} diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 97dd37950e..b8482338c6 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -7,23 +7,18 @@ import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} * all partitions. An example use case: If we know the RDD is partitioned by range, * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. + * + * TODO: This currently doesn't give partition IDs properly! 
*/ class PartitionPruningRDD[T: ClassManifest]( @transient prev: RDD[T], @transient partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { - @transient - var partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].partitions - override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) - override protected def getSplits = partitions_ + override protected def getSplits = + getDependencies.head.asInstanceOf[PruneDependency[T]].partitions override val partitioner = firstParent[T].partitioner - - override def clearDependencies() { - super.clearDependencies() - partitions_ = null - } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 28ff19876d..d396478673 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -22,16 +22,10 @@ class ShuffledRDD[K, V]( override val partitioner = Some(part) - @transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) - - override def getSplits = splits_ + override def getSplits = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) } - - override def clearDependencies() { - splits_ = null - } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 82f0a44ecd..26a2d511f2 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -26,9 +26,9 @@ private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIn class UnionRDD[T: ClassManifest]( sc: SparkContext, @transient var rdds: Seq[RDD[T]]) - extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs + extends RDD[T](sc, Nil) { // Nil since we implement getDependencies - @transient var splits_ : Array[Split] = { + override def getSplits: Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) var pos = 0 for (rdd <- rdds; split <- rdd.splits) { @@ -38,20 +38,16 @@ class UnionRDD[T: ClassManifest]( array } - override def getSplits = splits_ - - @transient var deps_ = { + override def getDependencies: Seq[Dependency[_]] = { val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) pos += rdd.splits.size } - deps.toList + deps } - override def getDependencies = deps_ - override def compute(s: Split, context: TaskContext): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator(context) @@ -59,8 +55,6 @@ class UnionRDD[T: ClassManifest]( s.asInstanceOf[UnionSplit[T]].preferredLocations() override def clearDependencies() { - deps_ = null - splits_ = null rdds = null } } diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala index d950b06c85..e5df6d8c72 100644 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -32,9 +32,7 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) with Serializable { - // TODO: FIX THIS. 
- - @transient var splits_ : Array[Split] = { + override def getSplits: Array[Split] = { if (rdd1.splits.size != rdd2.splits.size) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } @@ -45,8 +43,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( array } - override def getSplits = splits_ - override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = { val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context)) @@ -58,7 +54,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( } override def clearDependencies() { - splits_ = null rdd1 = null rdd2 = null } diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index 6cf93a9b17..eaff7ae581 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -26,8 +26,8 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging if (delaySeconds > 0) { logDebug( - "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and " - + "period of " + periodSeconds + " secs") + "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " + + "and period of " + periodSeconds + " secs") timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 33c317720c..0b74607fb8 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits. // Note that this test is very specific to the current implementation of CartesianRDD. val ones = sc.makeRDD(1 to 100, 10).map(x => x) - ones.checkpoint // checkpoint that MappedRDD + ones.checkpoint() // checkpoint that MappedRDD val cartesian = new CartesianRDD(sc, ones, ones) val splitBeforeCheckpoint = serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit]) @@ -125,7 +125,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits. // Note that this test is very specific to the current implementation of CoalescedRDDSplits val ones = sc.makeRDD(1 to 100, 10).map(x => x) - ones.checkpoint // checkpoint that MappedRDD + ones.checkpoint() // checkpoint that MappedRDD val coalesced = new CoalescedRDD(ones, 2) val splitBeforeCheckpoint = serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit]) @@ -160,7 +160,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { // so only the RDD will reduce in serialized size, not the splits. 
testParentCheckpointing( rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) - } /** @@ -176,7 +175,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testRDDSplitSize: Boolean = false ) { // Generate the final RDD using given RDD operation - val baseRDD = generateLongLineageRDD + val baseRDD = generateLongLineageRDD() val operatedRDD = op(baseRDD) val parentRDD = operatedRDD.dependencies.headOption.orNull val rddType = operatedRDD.getClass.getSimpleName @@ -245,12 +244,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testRDDSplitSize: Boolean ) { // Generate the final RDD using given RDD operation - val baseRDD = generateLongLineageRDD + val baseRDD = generateLongLineageRDD() val operatedRDD = op(baseRDD) val parentRDD = operatedRDD.dependencies.head.rdd val rddType = operatedRDD.getClass.getSimpleName val parentRDDType = parentRDD.getClass.getSimpleName + // Get the splits and dependencies of the parent in case they're lazily computed + parentRDD.dependencies + parentRDD.splits + // Find serialized sizes before and after the checkpoint val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one @@ -267,7 +270,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { if (testRDDSize) { assert( rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, - "Size of " + rddType + " did not reduce after parent checkpointing parent " + parentRDDType + + "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" ) } @@ -318,10 +321,12 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } /** - * Get serialized sizes of the RDD and its splits + * Get serialized sizes of the RDD and its splits, in order to test whether the size shrinks + * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. */ def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { - (Utils.serialize(rdd).size, Utils.serialize(rdd.splits).size) + (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length, + Utils.serialize(rdd.splits).length) } /** -- cgit v1.2.3 From 16a0789e10d2ac714e7c623b026c4a58ca9678d6 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 29 Jan 2013 17:09:53 -0800 Subject: Remember ConnectionManagerId used to initiate SendingConnections. This prevents ConnectionManager from getting confused if a machine has multiple host names and the one getHostName() finds happens not to be the one that was passed from, e.g., the BlockManagerMaster. 
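The change above can be read as the following minimal sketch (placeholder class names, not the real spark.network types): the sending side keeps the ConnectionManagerId it was asked to connect to, instead of re-deriving one from the socket's remote address, so lookups keyed on that id still match on multi-homed machines.

import java.net.InetSocketAddress

// Illustrative stand-ins only; the real classes live in spark.network.
case class ConnectionManagerId(host: String, port: Int) {
  def toSocketAddress = new InetSocketAddress(host, port)
}

class SendingConnectionSketch(val address: InetSocketAddress,
                              val remoteId: ConnectionManagerId) {
  // Bookkeeping is keyed on the id the caller (e.g. the BlockManagerMaster)
  // handed out, not on whatever host name the socket happens to resolve to.
  def bookkeepingKey: ConnectionManagerId = remoteId
}

object SendingConnectionSketch {
  def main(args: Array[String]) {
    val remoteId = ConnectionManagerId("worker-1.internal", 7077)
    val conn = new SendingConnectionSketch(remoteId.toSocketAddress, remoteId)
    println(conn.bookkeepingKey)  // ConnectionManagerId(worker-1.internal,7077)
  }
}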
--- core/src/main/scala/spark/network/Connection.scala | 15 +++++++++++---- core/src/main/scala/spark/network/ConnectionManager.scala | 3 ++- 2 files changed, 13 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index c193bf7c8d..cd5b7d57f3 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -12,7 +12,14 @@ import java.net._ private[spark] -abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging { +abstract class Connection(val channel: SocketChannel, val selector: Selector, + val remoteConnectionManagerId: ConnectionManagerId) extends Logging { + def this(channel_ : SocketChannel, selector_ : Selector) = { + this(channel_, selector_, + ConnectionManagerId.fromSocketAddress( + channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + )) + } channel.configureBlocking(false) channel.socket.setTcpNoDelay(true) @@ -25,7 +32,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() - val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress) def key() = channel.keyFor(selector) @@ -103,8 +109,9 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex } -private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector) -extends Connection(SocketChannel.open, selector_) { +private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector, + remoteId_ : ConnectionManagerId) +extends Connection(SocketChannel.open, selector_, remoteId_) { class Outbox(fair: Int = 0) { val messages = new Queue[Message]() diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 2ecd14f536..c7f226044d 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -299,7 +299,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector)) + val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, + new SendingConnection(inetSocketAddress, selector, connectionManagerId)) newConnection } val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress) -- cgit v1.2.3 From f7de6978c14a331683e4a341fccd6e4c5e9fa523 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 29 Jan 2013 14:03:05 -0800 Subject: Use Mesos ExecutorIDs to hold SlaveIDs. Then we can safely use the Mesos ExecutorID as a Spark ExecutorID. 
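A hedged sketch of the id scheme this change moves to, with simplified case classes standing in for the Mesos protobuf builders: the executor info is built per slave with the slave id as its executor id, so the id the backend later reports to Spark is the same string and needs no concatenation.

// Simplified stand-ins for the Mesos types; only the flow of the id string matters here.
case class SlaveInfoSketch(id: String, hostname: String)
case class ExecutorInfoSketch(executorId: String)

object MesosIdSketch {
  // Scheduler side: one ExecutorInfo per slave, keyed by the slave id.
  def createExecutorInfo(slaveId: String) = ExecutorInfoSketch(executorId = slaveId)

  // Executor backend side: the id Mesos reports back is used verbatim as the Spark executor id.
  def sparkExecutorId(info: ExecutorInfoSketch) = info.executorId

  def main(args: Array[String]) {
    val slave = SlaveInfoSketch("20130129-170953-S3", "worker-1")
    val info = createExecutorInfo(slave.id)
    println(sparkExecutorId(info))  // 20130129-170953-S3
  }
}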
--- .../spark/executor/MesosExecutorBackend.scala | 6 ++++- .../scheduler/mesos/MesosSchedulerBackend.scala | 30 ++++++++++++---------- 2 files changed, 21 insertions(+), 15 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala index 1ef88075ad..b981b26916 100644 --- a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala @@ -32,7 +32,11 @@ private[spark] class MesosExecutorBackend(executor: Executor) logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue) this.driver = driver val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) - executor.initialize(executorInfo.getExecutorId.getValue, slaveInfo.getHostname, properties) + executor.initialize( + slaveInfo.getId.getValue + "-" + executorInfo.getExecutorId.getValue, + slaveInfo.getHostname, + properties + ) } override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index f3467db86b..eab1c60e0b 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -51,7 +51,7 @@ private[spark] class MesosSchedulerBackend( val taskIdToSlaveId = new HashMap[Long, String] // An ExecutorInfo for our tasks - var executorInfo: ExecutorInfo = null + var execArgs: Array[Byte] = null override def start() { synchronized { @@ -70,12 +70,11 @@ private[spark] class MesosSchedulerBackend( } }.start() - executorInfo = createExecutorInfo() waitForRegister() } } - def createExecutorInfo(): ExecutorInfo = { + def createExecutorInfo(execId: String): ExecutorInfo = { val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( "Spark home is not set; set it through the spark.home system " + "property, the SPARK_HOME environment variable or the SparkContext constructor")) @@ -97,7 +96,7 @@ private[spark] class MesosSchedulerBackend( .setEnvironment(environment) .build() ExecutorInfo.newBuilder() - .setExecutorId(ExecutorID.newBuilder().setValue("default").build()) + .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) .addResources(memory) @@ -109,17 +108,20 @@ private[spark] class MesosSchedulerBackend( * containing all the spark.* system properties in the form of (String, String) pairs. 
*/ private def createExecArg(): Array[Byte] = { - val props = new HashMap[String, String] - val iterator = System.getProperties.entrySet.iterator - while (iterator.hasNext) { - val entry = iterator.next - val (key, value) = (entry.getKey.toString, entry.getValue.toString) - if (key.startsWith("spark.")) { - props(key) = value + if (execArgs == null) { + val props = new HashMap[String, String] + val iterator = System.getProperties.entrySet.iterator + while (iterator.hasNext) { + val entry = iterator.next + val (key, value) = (entry.getKey.toString, entry.getValue.toString) + if (key.startsWith("spark.")) { + props(key) = value + } } + // Serialize the map as an array of (String, String) pairs + execArgs = Utils.serialize(props.toArray) } - // Serialize the map as an array of (String, String) pairs - return Utils.serialize(props.toArray) + return execArgs } override def offerRescinded(d: SchedulerDriver, o: OfferID) {} @@ -216,7 +218,7 @@ private[spark] class MesosSchedulerBackend( return MesosTaskInfo.newBuilder() .setTaskId(taskId) .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) - .setExecutor(executorInfo) + .setExecutor(createExecutorInfo(slaveId)) .setName(task.name) .addResources(cpuResource) .setData(ByteString.copyFrom(task.serializedTask)) -- cgit v1.2.3 From 252845d3046034d6e779bd7245d2f876debba8fd Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 30 Jan 2013 10:38:06 -0800 Subject: Remove remnants of attempt to use slaveId-executorId in MesosExecutorBackend --- core/src/main/scala/spark/executor/MesosExecutorBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala index b981b26916..818d6d1dda 100644 --- a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala @@ -33,7 +33,7 @@ private[spark] class MesosExecutorBackend(executor: Executor) this.driver = driver val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) executor.initialize( - slaveInfo.getId.getValue + "-" + executorInfo.getExecutorId.getValue, + executorInfo.getExecutorId.getValue, slaveInfo.getHostname, properties ) -- cgit v1.2.3 From 871476d506a2d543482defb923a42a2a01f206ab Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Wed, 30 Jan 2013 16:56:46 -0600 Subject: Include message and exitStatus if available.
--- core/src/main/scala/spark/deploy/worker/Worker.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 5a83a42daf..8b41620d98 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -134,7 +134,9 @@ private[spark] class Worker( val fullId = jobId + "/" + execId if (ExecutorState.isFinished(state)) { val executor = executors(fullId) - logInfo("Executor " + fullId + " finished with state " + state) + logInfo("Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) finishedExecutors(fullId) = executor executors -= fullId coresUsed -= executor.cores -- cgit v1.2.3 From fe3eceab5724bec0103471eb905bb9701120b04a Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Thu, 31 Jan 2013 13:30:41 -0800 Subject: Remove activation of profiles by default See the discussion at https://github.com/mesos/spark/pull/355 for why default profile activation is a problem. --- bagel/pom.xml | 11 ----------- core/pom.xml | 11 ----------- examples/pom.xml | 11 ----------- pom.xml | 11 ----------- repl-bin/pom.xml | 11 ----------- repl/pom.xml | 11 ----------- streaming/pom.xml | 11 ----------- 7 files changed, 77 deletions(-) (limited to 'core') diff --git a/bagel/pom.xml b/bagel/pom.xml index 5f58347204..a8256a6e8b 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -45,11 +45,6 @@ hadoop1 - - - !hadoopVersion - - org.spark-project @@ -77,12 +72,6 @@ hadoop2 - - - hadoopVersion - 2 - - org.spark-project diff --git a/core/pom.xml b/core/pom.xml index 862d3ec37a..873e8a1d0f 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -163,11 +163,6 @@ hadoop1 - - - !hadoopVersion - - org.apache.hadoop @@ -220,12 +215,6 @@ hadoop2 - - - hadoopVersion - 2 - - org.apache.hadoop diff --git a/examples/pom.xml b/examples/pom.xml index 4d43103475..f43af670c6 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -50,11 +50,6 @@ hadoop1 - - - !hadoopVersion - - org.spark-project @@ -88,12 +83,6 @@ hadoop2 - - - hadoopVersion - 2 - - org.spark-project diff --git a/pom.xml b/pom.xml index 3ea989a082..c6b9012dc6 100644 --- a/pom.xml +++ b/pom.xml @@ -499,11 +499,6 @@ hadoop1 - - - !hadoopVersion - - 1 @@ -521,12 +516,6 @@ hadoop2 - - - hadoopVersion - 2 - - 2 diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index da91c0f3ab..0667b71cc7 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -70,11 +70,6 @@ hadoop1 - - - !hadoopVersion - - hadoop1 @@ -115,12 +110,6 @@ hadoop2 - - - hadoopVersion - 2 - - hadoop2 diff --git a/repl/pom.xml b/repl/pom.xml index 2dc96beaf5..4a296fa630 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -72,11 +72,6 @@ hadoop1 - - - !hadoopVersion - - hadoop1 @@ -128,12 +123,6 @@ hadoop2 - - - hadoopVersion - 2 - - hadoop2 diff --git a/streaming/pom.xml b/streaming/pom.xml index 3dae815e1a..6ee7e59df3 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -83,11 +83,6 @@ hadoop1 - - - !hadoopVersion - - org.spark-project @@ -115,12 +110,6 @@ hadoop2 - - - hadoopVersion - 2 - - org.spark-project -- cgit v1.2.3 From 418e36caa8fcd9a70026ab762ec709732fdebd6b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Thu, 31 Jan 2013 17:18:33 -0600 Subject: Add more private declarations. 
--- core/src/main/scala/spark/MapOutputTracker.scala | 2 +- .../scala/spark/deploy/master/MasterWebUI.scala | 22 ++++------- .../main/scala/spark/scheduler/DAGScheduler.scala | 46 +++++++++++----------- .../scala/spark/scheduler/ShuffleMapTask.scala | 3 +- .../spark/scheduler/cluster/ClusterScheduler.scala | 2 +- .../spark/scheduler/cluster/TaskSetManager.scala | 19 ++++----- .../spark/scheduler/local/LocalScheduler.scala | 4 +- .../main/scala/spark/util/MetadataCleaner.scala | 10 ++--- 8 files changed, 49 insertions(+), 59 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index aaf433b324..4735207585 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -170,7 +170,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea } } - def cleanup(cleanupTime: Long) { + private def cleanup(cleanupTime: Long) { mapStatuses.clearOldValues(cleanupTime) cachedSerializedStatuses.clearOldValues(cleanupTime) } diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index a01774f511..529f72e9da 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -45,13 +45,9 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) => val future = master ? RequestMasterState val jobInfo = for (masterState <- future.mapTo[MasterState]) yield { - masterState.activeJobs.find(_.id == jobId) match { - case Some(job) => job - case _ => masterState.completedJobs.find(_.id == jobId) match { - case Some(job) => job - case _ => null - } - } + masterState.activeJobs.find(_.id == jobId).getOrElse({ + masterState.completedJobs.find(_.id == jobId).getOrElse(null) + }) } respondWithMediaType(MediaTypes.`application/json`) { ctx => ctx.complete(jobInfo.mapTo[JobInfo]) @@ -61,14 +57,10 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct val future = master ? 
RequestMasterState future.map { state => val masterState = state.asInstanceOf[MasterState] - - masterState.activeJobs.find(_.id == jobId) match { - case Some(job) => spark.deploy.master.html.job_details.render(job) - case _ => masterState.completedJobs.find(_.id == jobId) match { - case Some(job) => spark.deploy.master.html.job_details.render(job) - case _ => null - } - } + val job = masterState.activeJobs.find(_.id == jobId).getOrElse({ + masterState.completedJobs.find(_.id == jobId).getOrElse(null) + }) + spark.deploy.master.html.job_details.render(job) } } } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index b130be6a38..14f61f7e87 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -97,7 +97,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } }.start() - def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { + private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { if (!cacheLocs.contains(rdd.id)) { val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map { @@ -107,7 +107,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with cacheLocs(rdd.id) } - def clearCacheLocs() { + private def clearCacheLocs() { cacheLocs.clear() } @@ -116,7 +116,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * The priority value passed in will be used if the stage doesn't already exist with * a lower priority (we assume that priorities always increase across jobs for now). */ - def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = { + private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => @@ -131,11 +131,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * as a result stage for the final RDD used directly in an action. The stage will also be given * the provided priority. */ - def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = { - // Kind of ugly: need to register RDDs with the cache and map output tracker here - // since we can't do it in the RDD constructor because # of splits is unknown - logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") + private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = { if (shuffleDep != None) { + // Kind of ugly: need to register RDDs with the cache and map output tracker here + // since we can't do it in the RDD constructor because # of splits is unknown + logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) } val id = nextStageId.getAndIncrement() @@ -148,7 +148,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * Get or create the list of parent stages for a given RDD. The stages will be assigned the * provided priority if they haven't already been created with a lower priority. 
*/ - def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = { + private def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] def visit(r: RDD[_]) { @@ -170,7 +170,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with parents.toList } - def getMissingParentStages(stage: Stage): List[Stage] = { + private def getMissingParentStages(stage: Stage): List[Stage] = { val missing = new HashSet[Stage] val visited = new HashSet[RDD[_]] def visit(rdd: RDD[_]) { @@ -241,7 +241,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * events and responds by launching tasks. This runs in a dedicated thread and receives events * via the eventQueue. */ - def run() { + private def run() { SparkEnv.set(env) while (true) { @@ -326,7 +326,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * We run the operation in a separate thread just in case it takes a bunch of time, so that we * don't block the DAGScheduler event loop or other concurrent jobs. */ - def runLocally(job: ActiveJob) { + private def runLocally(job: ActiveJob) { logInfo("Computing the requested partition locally") new Thread("Local computation of job " + job.runId) { override def run() { @@ -349,13 +349,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with }.start() } - def submitStage(stage: Stage) { + /** Submits stage, but first recursively submits any missing parents. */ + private def submitStage(stage: Stage) { logDebug("submitStage(" + stage + ")") if (!waiting(stage) && !running(stage) && !failed(stage)) { val missing = getMissingParentStages(stage).sortBy(_.id) logDebug("missing: " + missing) if (missing == Nil) { - logInfo("Submitting " + stage + " (" + stage.origin + "), which has no missing parents") + logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") submitMissingTasks(stage) running += stage } else { @@ -367,7 +368,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } } - def submitMissingTasks(stage: Stage) { + /** Called when stage's parents are available and we can now do its task. */ + private def submitMissingTasks(stage: Stage) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) @@ -388,7 +390,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } } if (tasks.size > 0) { - logInfo("Submitting " + tasks.size + " missing tasks from " + stage) + logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") myPending ++= tasks logDebug("New pending tasks: " + myPending) taskSched.submitTasks( @@ -407,7 +409,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * Responds to a task finishing. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. 
*/ - def handleTaskCompletion(event: CompletionEvent) { + private def handleTaskCompletion(event: CompletionEvent) { val task = event.task val stage = idToStage(task.stageId) @@ -492,7 +494,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with waiting --= newlyRunnable running ++= newlyRunnable for (stage <- newlyRunnable.sortBy(_.id)) { - logInfo("Submitting " + stage + " (" + stage.origin + "), which is now runnable") + logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable") submitMissingTasks(stage) } } @@ -541,7 +543,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * Optionally the generation during which the failure was caught can be passed to avoid allowing * stray fetch failures from possibly retriggering the detection of a node as lost. */ - def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) { + private def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) { val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration) if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) { failedGeneration(execId) = currentGeneration @@ -567,7 +569,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. */ - def abortStage(failedStage: Stage, reason: String) { + private def abortStage(failedStage: Stage, reason: String) { val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) @@ -583,7 +585,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with /** * Return true if one of stage's ancestors is target. 
*/ - def stageDependsOn(stage: Stage, target: Stage): Boolean = { + private def stageDependsOn(stage: Stage, target: Stage): Boolean = { if (stage == target) { return true } @@ -610,7 +612,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with visitedRdds.contains(target.rdd) } - def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = { + private def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = { // If the partition is cached, return the cache locations val cached = getCacheLocs(rdd)(partition) if (cached != Nil) { @@ -636,7 +638,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with return Nil } - def cleanup(cleanupTime: Long) { + private def cleanup(cleanupTime: Long) { var sizeBefore = idToStage.size idToStage.clearOldValues(cleanupTime) logInfo("idToStage " + sizeBefore + " --> " + idToStage.size) diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 83641a2a84..b701b67c89 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -127,7 +127,6 @@ private[spark] class ShuffleMapTask( val bucketId = dep.partitioner.getPartition(pair._1) buckets(bucketId) += pair } - val bucketIterators = buckets.map(_.iterator) val compressedSizes = new Array[Byte](numOutputSplits) @@ -135,7 +134,7 @@ private[spark] class ShuffleMapTask( for (i <- 0 until numOutputSplits) { val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i // Get a Scala iterator from Java map - val iter: Iterator[(Any, Any)] = bucketIterators(i) + val iter: Iterator[(Any, Any)] = buckets(i).iterator val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false) compressedSizes(i) = MapOutputTracker.compressSize(size) } diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 0b4177805b..1e4fbdb874 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -86,7 +86,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - def submitTasks(taskSet: TaskSet) { + override def submitTasks(taskSet: TaskSet) { val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 26201ad0dd..3dabdd76b1 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -17,10 +17,7 @@ import java.nio.ByteBuffer /** * Schedules the tasks within a single TaskSet in the ClusterScheduler. */ -private[spark] class TaskSetManager( - sched: ClusterScheduler, - val taskSet: TaskSet) - extends Logging { +private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Logging { // Maximum time to wait to run a task in a preferred location (in ms) val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong @@ -100,7 +97,7 @@ private[spark] class TaskSetManager( } // Add a task to all the pending-task lists that it should be on. 
- def addPendingTask(index: Int) { + private def addPendingTask(index: Int) { val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive if (locations.size == 0) { pendingTasksWithNoPrefs += index @@ -115,7 +112,7 @@ private[spark] class TaskSetManager( // Return the pending tasks list for a given host, or an empty list if // there is no map entry for that host - def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { + private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { pendingTasksForHost.getOrElse(host, ArrayBuffer()) } @@ -123,7 +120,7 @@ private[spark] class TaskSetManager( // Return None if the list is empty. // This method also cleans up any tasks in the list that have already // been launched, since we want that to happen lazily. - def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { + private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { while (!list.isEmpty) { val index = list.last list.trimEnd(1) @@ -137,7 +134,7 @@ private[spark] class TaskSetManager( // Return a speculative task for a given host if any are available. The task should not have an // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the // task must have a preference for this host (or no preferred locations at all). - def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { + private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { val hostsAlive = sched.hostsAlive speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set val localTask = speculatableTasks.find { @@ -162,7 +159,7 @@ private[spark] class TaskSetManager( // Dequeue a pending task for a given node and return its index. // If localOnly is set to false, allow non-local tasks as well. - def findTask(host: String, localOnly: Boolean): Option[Int] = { + private def findTask(host: String, localOnly: Boolean): Option[Int] = { val localTask = findTaskFromList(getPendingTasksForHost(host)) if (localTask != None) { return localTask @@ -184,7 +181,7 @@ private[spark] class TaskSetManager( // Does a host count as a preferred location for a task? This is true if // either the task has preferred locations and this host is one, or it has // no preferred locations (in which we still count the launch as preferred). 
- def isPreferredLocation(task: Task[_], host: String): Boolean = { + private def isPreferredLocation(task: Task[_], host: String): Boolean = { val locs = task.preferredLocations return (locs.contains(host) || locs.isEmpty) } @@ -335,7 +332,7 @@ private[spark] class TaskSetManager( if (numFailures(index) > MAX_TASK_FAILURES) { logError("Task %s:%d failed more than %d times; aborting job".format( taskSet.id, index, MAX_TASK_FAILURES)) - abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES)) + abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) } } } else { diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 9ff7c02097..482d1cc853 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon } def runTask(task: Task[_], idInJob: Int, attemptId: Int) { - logInfo("Running task " + idInJob) + logInfo("Running " + task) // Set the Spark execution environment for the worker thread SparkEnv.set(env) try { @@ -80,7 +80,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon val resultToReturn = ser.deserialize[Any](ser.serialize(result)) val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( ser.serialize(Accumulators.values)) - logInfo("Finished task " + idInJob) + logInfo("Finished " + task) // If the threadpool has not already been shutdown, notify DAGScheduler if (!Thread.currentThread().isInterrupted) diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index eaff7ae581..a342d378ff 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -9,12 +9,12 @@ import spark.Logging * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) */ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { - val delaySeconds = MetadataCleaner.getDelaySeconds - val periodSeconds = math.max(10, delaySeconds / 10) - val timer = new Timer(name + " cleanup timer", true) + private val delaySeconds = MetadataCleaner.getDelaySeconds + private val periodSeconds = math.max(10, delaySeconds / 10) + private val timer = new Timer(name + " cleanup timer", true) - val task = new TimerTask { - def run() { + private val task = new TimerTask { + override def run() { try { cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) logInfo("Ran metadata cleaner for " + name) -- cgit v1.2.3 From 5b0fc265c2f2ce461d61904c2a4e6e47b24d2bbe Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 31 Jan 2013 17:48:39 -0800 Subject: Changed PartitionPruningRDD's split to make sure it returns the correct split index. 
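A minimal sketch of the re-indexing the fix below introduces, using simplified split classes rather than the real Split/RDD types: pruned partitions are renumbered 0..n-1 while each wrapper remembers which parent partition it reads from.

object PruningIndexSketch {
  // Simplified stand-ins for the split types; not the real spark classes.
  case class ParentSplit(index: Int)
  case class PrunedSplit(index: Int, parentSplit: ParentSplit)

  def main(args: Array[String]) {
    val parentSplits = (0 until 5).map(i => ParentSplit(i))
    val pruned = parentSplits
      .filter(_.index % 2 == 0)   // keep parent partitions 0, 2, 4
      .zipWithIndex               // renumber them consecutively
      .map { case (parent, idx) => PrunedSplit(idx, parent) }

    println(pruned.map(_.index))             // Vector(0, 1, 2): indices the child RDD reports
    println(pruned.map(_.parentSplit.index)) // Vector(0, 2, 4): partitions compute() actually reads
  }
}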
--- core/src/main/scala/spark/Dependency.scala | 8 ++++++++ core/src/main/scala/spark/rdd/PartitionPruningRDD.scala | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index 647aee6eb5..827eac850a 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -72,6 +72,14 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo @transient val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index)) + .zipWithIndex + .map { case(split, idx) => new PruneDependency.PartitionPruningRDDSplit(idx, split) : Split } override def getParents(partitionId: Int) = List(partitions(partitionId).index) } + +object PruneDependency { + class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split { + override val index = idx + } +} diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index b8482338c6..0989b149e1 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -2,6 +2,7 @@ package spark.rdd import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} + /** * A RDD used to prune RDD partitions/splits so we can avoid launching tasks on * all partitions. An example use case: If we know the RDD is partitioned by range, @@ -15,7 +16,8 @@ class PartitionPruningRDD[T: ClassManifest]( @transient partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { - override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) + override def compute(split: Split, context: TaskContext) = firstParent[T].iterator( + split.asInstanceOf[PruneDependency.PartitionPruningRDDSplit].parentSplit, context) override protected def getSplits = getDependencies.head.asInstanceOf[PruneDependency[T]].partitions -- cgit v1.2.3 From 6289d9654e32fc92418d41cc6e32fee30f85c833 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 31 Jan 2013 17:49:36 -0800 Subject: Removed the TODO comment from PartitionPruningRDD. --- core/src/main/scala/spark/rdd/PartitionPruningRDD.scala | 2 -- 1 file changed, 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 0989b149e1..3756870fac 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -8,8 +8,6 @@ import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} * all partitions. An example use case: If we know the RDD is partitioned by range, * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. - * - * TODO: This currently doesn't give partition IDs properly! */ class PartitionPruningRDD[T: ClassManifest]( @transient prev: RDD[T], -- cgit v1.2.3 From 3446d5c8d6b385106ac85e46320d92faa8efb4e6 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 31 Jan 2013 18:02:28 -0800 Subject: SPARK-673: Capture and re-throw Python exceptions This patch alters the Python <-> executor protocol to pass on exception data when they occur in user Python code. 
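A hedged sketch of the framing convention described here, reduced to plain data streams rather than the real PythonRDD/worker pair: positive integers prefix normal records, -1 ends the data section, and -2 announces a serialized traceback that the JVM side turns back into an exception.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object PythonFramingSketch {
  // Toy version of the length-prefixed protocol; not the real PythonRDD code.
  def writeError(out: DataOutputStream, traceback: String) {
    val bytes = traceback.getBytes("UTF-8")
    out.writeInt(-2)            // signals that an exception follows
    out.writeInt(bytes.length)
    out.write(bytes)
  }

  def readRecord(in: DataInputStream): Array[Byte] = in.readInt() match {
    case length if length > 0 =>
      val obj = new Array[Byte](length)
      in.readFully(obj)
      obj
    case -2 =>
      val len = in.readInt()
      val err = new Array[Byte](len)
      in.readFully(err)
      throw new RuntimeException("Python worker failed:\n" + new String(err, "UTF-8"))
    case -1 =>
      Array.empty[Byte]         // end of the data section
  }

  def main(args: Array[String]) {
    val buffer = new ByteArrayOutputStream()
    writeError(new DataOutputStream(buffer), "Traceback (most recent call last): ...")
    try {
      readRecord(new DataInputStream(new ByteArrayInputStream(buffer.toByteArray)))
    } catch {
      case e: RuntimeException => println(e.getMessage)
    }
  }
}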
--- .../main/scala/spark/api/python/PythonRDD.scala | 40 ++++++++++++++-------- python/pyspark/worker.py | 10 ++++-- 2 files changed, 34 insertions(+), 16 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index f43a152ca7..6b9ef62529 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -103,21 +103,30 @@ private[spark] class PythonRDD[T: ClassManifest]( private def read(): Array[Byte] = { try { - val length = stream.readInt() - if (length != -1) { - val obj = new Array[Byte](length) - stream.readFully(obj) - obj - } else { - // We've finished the data section of the output, but we can still read some - // accumulator updates; let's do that, breaking when we get EOFException - while (true) { - val len2 = stream.readInt() - val update = new Array[Byte](len2) - stream.readFully(update) - accumulator += Collections.singletonList(update) + stream.readInt() match { + case length if length > 0 => { + val obj = new Array[Byte](length) + stream.readFully(obj) + obj } - new Array[Byte](0) + case -2 => { + // Signals that an exception has been thrown in python + val exLength = stream.readInt() + val obj = new Array[Byte](exLength) + stream.readFully(obj) + throw new PythonException(new String(obj)) + } + case -1 => { + // We've finished the data section of the output, but we can still read some + // accumulator updates; let's do that, breaking when we get EOFException + while (true) { + val len2 = stream.readInt() + val update = new Array[Byte](len2) + stream.readFully(update) + accumulator += Collections.singletonList(update) + } + new Array[Byte](0) + } } } catch { case eof: EOFException => { @@ -140,6 +149,9 @@ private[spark] class PythonRDD[T: ClassManifest]( val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } +/** Thrown for exceptions in user Python code. */ +private class PythonException(msg: String) extends Exception(msg) + /** * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. * This is used by PySpark's shuffle operations. diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d33d6dd15f..9622e0cfe4 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -2,6 +2,7 @@ Worker that receives input from Piped RDD. """ import sys +import traceback from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. 
@@ -40,8 +41,13 @@ def main(): else: dumps = dump_pickle iterator = read_from_pickle_file(sys.stdin) - for obj in func(split_index, iterator): - write_with_length(dumps(obj), old_stdout) + try: + for obj in func(split_index, iterator): + write_with_length(dumps(obj), old_stdout) + except Exception as e: + write_int(-2, old_stdout) + write_with_length(traceback.format_exc(), old_stdout) + sys.exit(-1) # Mark the beginning of the accumulators section of the output write_int(-1, old_stdout) for aid, accum in _accumulatorRegistry.items(): -- cgit v1.2.3 From c33f0ef41a1865de2bae01b52b860650d3734da4 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 31 Jan 2013 21:50:02 -0800 Subject: Some style cleanup --- core/src/main/scala/spark/api/python/PythonRDD.scala | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 6b9ef62529..23e3149248 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -104,19 +104,17 @@ private[spark] class PythonRDD[T: ClassManifest]( private def read(): Array[Byte] = { try { stream.readInt() match { - case length if length > 0 => { + case length if length > 0 => val obj = new Array[Byte](length) stream.readFully(obj) obj - } - case -2 => { + case -2 => // Signals that an exception has been thrown in python val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) throw new PythonException(new String(obj)) - } - case -1 => { + case -1 => // We've finished the data section of the output, but we can still read some // accumulator updates; let's do that, breaking when we get EOFException while (true) { @@ -124,9 +122,8 @@ private[spark] class PythonRDD[T: ClassManifest]( val update = new Array[Byte](len2) stream.readFully(update) accumulator += Collections.singletonList(update) + new Array[Byte](0) } - new Array[Byte](0) - } } } catch { case eof: EOFException => { -- cgit v1.2.3 From 39ab83e9577a5449fb0d6ef944dffc0d7cd00b4a Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 31 Jan 2013 21:52:52 -0800 Subject: Small fix from last commit --- core/src/main/scala/spark/api/python/PythonRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 23e3149248..39758e94f4 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -122,8 +122,8 @@ private[spark] class PythonRDD[T: ClassManifest]( val update = new Array[Byte](len2) stream.readFully(update) accumulator += Collections.singletonList(update) - new Array[Byte](0) } + new Array[Byte](0) } } catch { case eof: EOFException => { -- cgit v1.2.3 From f9af9cee6fed9c6af896fb92556ad4f48c7f8e64 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 1 Feb 2013 00:02:46 -0800 Subject: Moved PruneDependency into PartitionPruningRDD.scala. 
--- core/src/main/scala/spark/Dependency.scala | 22 ------------------ .../main/scala/spark/rdd/PartitionPruningRDD.scala | 26 ++++++++++++++++++---- 2 files changed, 22 insertions(+), 26 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index 827eac850a..5eea907322 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -61,25 +61,3 @@ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) } } } - - -/** - * Represents a dependency between the PartitionPruningRDD and its parent. In this - * case, the child RDD contains a subset of partitions of the parents'. - */ -class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) - extends NarrowDependency[T](rdd) { - - @transient - val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index)) - .zipWithIndex - .map { case(split, idx) => new PruneDependency.PartitionPruningRDDSplit(idx, split) : Split } - - override def getParents(partitionId: Int) = List(partitions(partitionId).index) -} - -object PruneDependency { - class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split { - override val index = idx - } -} diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 3756870fac..a50ce75171 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -1,6 +1,26 @@ package spark.rdd -import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} +import spark.{NarrowDependency, RDD, SparkEnv, Split, TaskContext} + + +class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split { + override val index = idx +} + + +/** + * Represents a dependency between the PartitionPruningRDD and its parent. In this + * case, the child RDD contains a subset of partitions of the parents'. + */ +class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) + extends NarrowDependency[T](rdd) { + + @transient + val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index)) + .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDSplit(idx, split) : Split } + + override def getParents(partitionId: Int) = List(partitions(partitionId).index) +} /** @@ -15,10 +35,8 @@ class PartitionPruningRDD[T: ClassManifest]( extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { override def compute(split: Split, context: TaskContext) = firstParent[T].iterator( - split.asInstanceOf[PruneDependency.PartitionPruningRDDSplit].parentSplit, context) + split.asInstanceOf[PartitionPruningRDDSplit].parentSplit, context) override protected def getSplits = getDependencies.head.asInstanceOf[PruneDependency[T]].partitions - - override val partitioner = firstParent[T].partitioner } -- cgit v1.2.3 From 59c57e48dfb362923610785b230d5b3b56c620c3 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 1 Feb 2013 10:34:02 -0600 Subject: Stop BlockManagers metadataCleaner. 
--- core/src/main/scala/spark/storage/BlockManager.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index c61fd75c2b..9893e9625d 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -950,6 +950,7 @@ class BlockManager( blockInfo.clear() memoryStore.clear() diskStore.clear() + metadataCleaner.cancel() logInfo("BlockManager stopped") } } -- cgit v1.2.3 From 9970926ede0d5a719b8f22e97977804d3c811e97 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 1 Feb 2013 14:07:34 -0800 Subject: formatting --- core/src/main/scala/spark/RDD.scala | 2 +- core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 210404d540..010e61dfdc 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -385,7 +385,7 @@ abstract class RDD[T: ClassManifest]( val reducePartition: Iterator[T] => Option[T] = iter => { if (iter.hasNext) { Some(iter.reduceLeft(cleanF)) - }else { + } else { None } } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 83641a2a84..20f2c9e489 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -32,7 +32,7 @@ private[spark] object ShuffleMapTask { return old } else { val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance + val ser = SparkEnv.get.closureSerializer.newInstance() val objOut = ser.serializeStream(new GZIPOutputStream(out)) objOut.writeObject(rdd) objOut.writeObject(dep) @@ -48,7 +48,7 @@ private[spark] object ShuffleMapTask { synchronized { val loader = Thread.currentThread.getContextClassLoader val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance + val ser = SparkEnv.get.closureSerializer.newInstance() val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] -- cgit v1.2.3 From 8b3041c7233011c4a96fab045a86df91eae7b6f3 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 1 Feb 2013 15:38:42 -0800 Subject: Reduced the memory usage of reduce and similar operations These operations used to wait for all the results to be available in an array on the driver program before merging them. They now merge values incrementally as they arrive. 
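A minimal sketch of the pattern this commit introduces (the local setup and the names here are assumed): instead of materializing an Array with one entry per partition, the driver passes a resultHandler into runJob and folds each task's value into a single running result as it arrives, just as the reworked reduce, fold and aggregate below do:

import spark.{RDD, SparkContext}

object IncrementalMergeSketch {
  // Sum an RDD without buffering one partial result per partition on the driver.
  def incrementalSum(sc: SparkContext, rdd: RDD[Long]): Long = {
    var total = 0L
    val sumPartition = (iter: Iterator[Long]) => iter.foldLeft(0L)(_ + _)
    // Called on the driver as each task finishes; only the running total is kept.
    val mergeResult = (index: Int, partial: Long) => total += partial
    sc.runJob(rdd, sumPartition, mergeResult)
    total
  }

  def main(args: Array[String]) {
    val sc = new SparkContext("local", "incremental-merge-sketch")  // assumed setup
    val nums = sc.parallelize(1L to 1000L, 8)
    assert(incrementalSum(sc, nums) == 500500L)
    sc.stop()
  }
}

The per-partition closure is the only piece shipped to executors; the merge closure stays on the driver, so holding mutable state in it is safe.
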
--- core/src/main/scala/spark/PairRDDFunctions.scala | 4 +- core/src/main/scala/spark/RDD.scala | 41 ++++++++++------- core/src/main/scala/spark/SparkContext.scala | 53 +++++++++++++++++++--- core/src/main/scala/spark/Utils.scala | 8 ++++ .../spark/partial/ApproximateActionListener.scala | 4 +- .../main/scala/spark/scheduler/DAGScheduler.scala | 15 +++--- .../src/main/scala/spark/scheduler/JobResult.scala | 2 +- .../src/main/scala/spark/scheduler/JobWaiter.scala | 14 +++--- core/src/test/scala/spark/RDDSuite.scala | 12 +++-- 9 files changed, 107 insertions(+), 46 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 231e23a7de..cc3cca2571 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -465,7 +465,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( val res = self.context.runJob(self, process _, Array(index), false) res(0) case None => - self.filter(_._1 == key).map(_._2).collect + self.filter(_._1 == key).map(_._2).collect() } } @@ -590,7 +590,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( var count = 0 while(iter.hasNext) { - val record = iter.next + val record = iter.next() count += 1 writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) } diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 010e61dfdc..9d6ea782bd 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -389,16 +389,18 @@ abstract class RDD[T: ClassManifest]( None } } - val options = sc.runJob(this, reducePartition) - val results = new ArrayBuffer[T] - for (opt <- options; elem <- opt) { - results += elem - } - if (results.size == 0) { - throw new UnsupportedOperationException("empty collection") - } else { - return results.reduceLeft(cleanF) + var jobResult: Option[T] = None + val mergeResult = (index: Int, taskResult: Option[T]) => { + if (taskResult != None) { + jobResult = jobResult match { + case Some(value) => Some(f(value, taskResult.get)) + case None => taskResult + } + } } + sc.runJob(this, reducePartition, mergeResult) + // Get the final result out of our Option, or throw an exception if the RDD was empty + jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")) } /** @@ -408,9 +410,13 @@ abstract class RDD[T: ClassManifest]( * modify t2. */ def fold(zeroValue: T)(op: (T, T) => T): T = { + // Clone the zero value since we will also be serializing it as part of tasks + var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) val cleanOp = sc.clean(op) - val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp)) - return results.fold(zeroValue)(cleanOp) + val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp) + val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult) + sc.runJob(this, foldPartition, mergeResult) + jobResult } /** @@ -422,11 +428,14 @@ abstract class RDD[T: ClassManifest]( * allocation. 
*/ def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { + // Clone the zero value since we will also be serializing it as part of tasks + var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) val cleanSeqOp = sc.clean(seqOp) val cleanCombOp = sc.clean(combOp) - val results = sc.runJob(this, - (iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)) - return results.fold(zeroValue)(cleanCombOp) + val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) + val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult) + sc.runJob(this, aggregatePartition, mergeResult) + jobResult } /** @@ -437,7 +446,7 @@ abstract class RDD[T: ClassManifest]( var result = 0L while (iter.hasNext) { result += 1L - iter.next + iter.next() } result }).sum @@ -452,7 +461,7 @@ abstract class RDD[T: ClassManifest]( var result = 0L while (iter.hasNext) { result += 1L - iter.next + iter.next() } result } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index b0d4b58240..ddbf8f95d9 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -543,26 +543,42 @@ class SparkContext( } /** - * Run a function on a given set of partitions in an RDD and return the results. This is the main - * entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies - * whether the scheduler can run the computation on the driver rather than shipping it out to the - * cluster, for short actions like first(). + * Run a function on a given set of partitions in an RDD and pass the results to the given + * handler function. This is the main entry point for all actions in Spark. The allowLocal + * flag specifies whether the scheduler can run the computation on the driver rather than + * shipping it out to the cluster, for short actions like first(). */ def runJob[T, U: ClassManifest]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], - allowLocal: Boolean - ): Array[U] = { + allowLocal: Boolean, + resultHandler: (Int, U) => Unit) { val callSite = Utils.getSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal) + val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() result } + /** + * Run a function on a given set of partitions in an RDD and return the results as an array. The + * allowLocal flag specifies whether the scheduler can run the computation on the driver rather + * than shipping it out to the cluster, for short actions like first(). + */ + def runJob[T, U: ClassManifest]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + allowLocal: Boolean + ): Array[U] = { + val results = new Array[U](partitions.size) + runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res) + results + } + /** * Run a job on a given set of partitions of an RDD, but take a function of type * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. @@ -590,6 +606,29 @@ class SparkContext( runJob(rdd, func, 0 until rdd.splits.size, false) } + /** + * Run a job on all partitions in an RDD and pass the results to a handler function. 
+ */ + def runJob[T, U: ClassManifest]( + rdd: RDD[T], + processPartition: (TaskContext, Iterator[T]) => U, + resultHandler: (Int, U) => Unit) + { + runJob[T, U](rdd, processPartition, 0 until rdd.splits.size, false, resultHandler) + } + + /** + * Run a job on all partitions in an RDD and pass the results to a handler function. + */ + def runJob[T, U: ClassManifest]( + rdd: RDD[T], + processPartition: Iterator[T] => U, + resultHandler: (Int, U) => Unit) + { + val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter) + runJob[T, U](rdd, processFunc, 0 until rdd.splits.size, false, resultHandler) + } + /** * Run a job that can return approximate results. */ diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 1e58d01273..28d643abca 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -12,6 +12,7 @@ import scala.io.Source import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder import scala.Some +import spark.serializer.SerializerInstance /** * Various utility methods used by Spark. @@ -446,4 +447,11 @@ private object Utils extends Logging { socket.close() portBound } + + /** + * Clone an object using a Spark serializer. + */ + def clone[T](value: T, serializer: SerializerInstance): T = { + serializer.deserialize[T](serializer.serialize(value)) + } } diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala index 42f46e06ed..24b4909380 100644 --- a/core/src/main/scala/spark/partial/ApproximateActionListener.scala +++ b/core/src/main/scala/spark/partial/ApproximateActionListener.scala @@ -32,7 +32,7 @@ private[spark] class ApproximateActionListener[T, U, R]( if (finishedTasks == totalTasks) { // If we had already returned a PartialResult, set its final value resultObject.foreach(r => r.setFinalValue(evaluator.currentResult())) - // Notify any waiting thread that may have called getResult + // Notify any waiting thread that may have called awaitResult this.notifyAll() } } @@ -49,7 +49,7 @@ private[spark] class ApproximateActionListener[T, U, R]( * Waits for up to timeout milliseconds since the listener was created and then returns a * PartialResult with the result so far. This may be complete if the whole job is done. 
*/ - def getResult(): PartialResult[R] = synchronized { + def awaitResult(): PartialResult[R] = synchronized { val finishTime = startTime + timeout while (true) { val time = System.currentTimeMillis() diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 14f61f7e87..908a22b2df 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -203,18 +203,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: String, - allowLocal: Boolean) - : Array[U] = + allowLocal: Boolean, + resultHandler: (Int, U) => Unit) { if (partitions.size == 0) { - return new Array[U](0) + return } - val waiter = new JobWaiter(partitions.size) + val waiter = new JobWaiter(partitions.size, resultHandler) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)) - waiter.getResult() match { - case JobSucceeded(results: Seq[_]) => - return results.asInstanceOf[Seq[U]].toArray + waiter.awaitResult() match { + case JobSucceeded => {} case JobFailed(exception: Exception) => logInfo("Failed to run " + callSite) throw exception @@ -233,7 +232,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.splits.size).toArray eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener)) - return listener.getResult() // Will throw an exception if the job fails + return listener.awaitResult() // Will throw an exception if the job fails } /** diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala index c4a74e526f..654131ee84 100644 --- a/core/src/main/scala/spark/scheduler/JobResult.scala +++ b/core/src/main/scala/spark/scheduler/JobResult.scala @@ -5,5 +5,5 @@ package spark.scheduler */ private[spark] sealed trait JobResult -private[spark] case class JobSucceeded(results: Seq[_]) extends JobResult +private[spark] case object JobSucceeded extends JobResult private[spark] case class JobFailed(exception: Exception) extends JobResult diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala index b3d4feebe5..3cc6a86345 100644 --- a/core/src/main/scala/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/spark/scheduler/JobWaiter.scala @@ -3,10 +3,12 @@ package spark.scheduler import scala.collection.mutable.ArrayBuffer /** - * An object that waits for a DAGScheduler job to complete. + * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their + * results to the given handler function. */ -private[spark] class JobWaiter(totalTasks: Int) extends JobListener { - private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null) +private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) + extends JobListener { + private var finishedTasks = 0 private var jobFinished = false // Is the job as a whole finished (succeeded or failed)? 
@@ -17,11 +19,11 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener { if (jobFinished) { throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") } - taskResults(index) = result + resultHandler(index, result.asInstanceOf[T]) finishedTasks += 1 if (finishedTasks == totalTasks) { jobFinished = true - jobResult = JobSucceeded(taskResults) + jobResult = JobSucceeded this.notifyAll() } } @@ -38,7 +40,7 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener { } } - def getResult(): JobResult = synchronized { + def awaitResult(): JobResult = synchronized { while (!jobFinished) { this.wait() } diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index ed03e65153..95d2e62730 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -12,9 +12,9 @@ class RDDSuite extends FunSuite with LocalSparkContext { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(nums.collect().toList === List(1, 2, 3, 4)) val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) - assert(dups.distinct.count === 4) - assert(dups.distinct().collect === dups.distinct.collect) - assert(dups.distinct(2).collect === dups.distinct.collect) + assert(dups.distinct().count === 4) + assert(dups.distinct().collect === dups.distinct().collect) + assert(dups.distinct(2).collect === dups.distinct().collect) assert(nums.reduce(_ + _) === 10) assert(nums.fold(0)(_ + _) === 10) assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4")) @@ -31,6 +31,10 @@ class RDDSuite extends FunSuite with LocalSparkContext { case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) } assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7))) + + intercept[UnsupportedOperationException] { + nums.filter(_ > 5).reduce(_ + _) + } } test("SparkContext.union") { @@ -164,7 +168,7 @@ class RDDSuite extends FunSuite with LocalSparkContext { // Note that split number starts from 0, so > 8 means only 10th partition left. val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) assert(prunedRdd.splits.size === 1) - val prunedData = prunedRdd.collect + val prunedData = prunedRdd.collect() assert(prunedData.size === 1) assert(prunedData(0) === 10) } -- cgit v1.2.3 From 12c1eb47568060efac57d6df7df7e5704a8d3fab Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 1 Feb 2013 21:21:44 -0600 Subject: Reduce the amount of duplicate logging Akka does to stdout. Given we have Akka logging go through SLF4j to log4j, we don't need all the extra noise of Akka's stdout logger that is supposedly only used during Akka init time but seems to continue logging lots of noisy network events that we either don't care about or are in the log4j logs anyway. 
See: http://doc.akka.io/docs/akka/2.0/general/configuration.html # Log level for the very basic logger activated during AkkaApplication startup # Options: ERROR, WARNING, INFO, DEBUG # stdout-loglevel = "WARNING" --- core/src/main/scala/spark/util/AkkaUtils.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index e0fdeffbc4..e43fbd6b1c 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -30,6 +30,7 @@ private[spark] object AkkaUtils { val akkaConf = ConfigFactory.parseString(""" akka.daemonic = on akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] + akka.stdout-loglevel = "ERROR" akka.actor.provider = "akka.remote.RemoteActorRefProvider" akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" akka.remote.log-remote-lifecycle-events = on -- cgit v1.2.3 From ae26911ec0d768dcdae8b7d706ca4544e36535e6 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 1 Feb 2013 21:07:24 -0800 Subject: Add back test for distinct without parens --- core/src/test/scala/spark/RDDSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 95d2e62730..89a3687386 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -12,7 +12,8 @@ class RDDSuite extends FunSuite with LocalSparkContext { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(nums.collect().toList === List(1, 2, 3, 4)) val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) - assert(dups.distinct().count === 4) + assert(dups.distinct().count() === 4) + assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses? assert(dups.distinct().collect === dups.distinct().collect) assert(dups.distinct(2).collect === dups.distinct().collect) assert(nums.reduce(_ + _) === 10) -- cgit v1.2.3
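Tying the AkkaUtils change above together, a minimal standalone sketch (assuming akka-actor, akka-slf4j and a log4j binding are on the classpath; the system name is arbitrary) of an ActorSystem configured the same way, with events routed to SLF4J and the bootstrap stdout logger held at ERROR:

import akka.actor.ActorSystem
import com.typesafe.config.ConfigFactory

object QuietAkkaSketch {
  def main(args: Array[String]) {
    // Same two settings the AkkaUtils diff relies on: events go to the SLF4J
    // handler (and from there to log4j), while the basic stdout logger that is
    // active during ActorSystem startup only reports errors.
    val conf = ConfigFactory.parseString("""
      akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
      akka.stdout-loglevel = "ERROR"
      """)
    val system = ActorSystem("quiet", conf)
    system.log.info("this goes through SLF4J/log4j, not to stdout")
    system.shutdown()
  }
}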