From 47b7ebad12a17218f6ca0301fc802c0e0a81d873 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 28 Jul 2012 20:03:26 -0700 Subject: Added the Spark Streaing code, ported to Akka 2 --- core/src/main/scala/spark/BlockRDD.scala | 42 ++++++++++++++++++++++++++++ core/src/main/scala/spark/SparkContext.scala | 5 ++++ 2 files changed, 47 insertions(+) create mode 100644 core/src/main/scala/spark/BlockRDD.scala (limited to 'core') diff --git a/core/src/main/scala/spark/BlockRDD.scala b/core/src/main/scala/spark/BlockRDD.scala new file mode 100644 index 0000000000..ea009f0f4f --- /dev/null +++ b/core/src/main/scala/spark/BlockRDD.scala @@ -0,0 +1,42 @@ +package spark + +import scala.collection.mutable.HashMap + +class BlockRDDSplit(val blockId: String, idx: Int) extends Split { + val index = idx +} + + +class BlockRDD[T: ClassManifest](sc: SparkContext, blockIds: Array[String]) extends RDD[T](sc) { + + @transient + val splits_ = (0 until blockIds.size).map(i => { + new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] + }).toArray + + @transient + lazy val locations_ = { + val blockManager = SparkEnv.get.blockManager + /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ + val locations = blockManager.getLocations(blockIds) + HashMap(blockIds.zip(locations):_*) + } + + override def splits = splits_ + + override def compute(split: Split): Iterator[T] = { + val blockManager = SparkEnv.get.blockManager + val blockId = split.asInstanceOf[BlockRDDSplit].blockId + blockManager.get(blockId) match { + case Some(block) => block.asInstanceOf[Iterator[T]] + case None => + throw new Exception("Could not compute split, block " + blockId + " not found") + } + } + + override def preferredLocations(split: Split) = + locations_(split.asInstanceOf[BlockRDDSplit].blockId) + + override val dependencies: List[Dependency[_]] = Nil +} + diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index dd17d4d6b3..78c7618542 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -409,6 +409,11 @@ class SparkContext( * various Spark features. */ object SparkContext { + + // TODO: temporary hack for using HDFS as input in streaing + var inputFile: String = null + var idealPartitions: Int = 1 + implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double) = 0.0 -- cgit v1.2.3 From 3be54c2a8afcb2a3abf1cf22934123fae3419278 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 1 Aug 2012 22:09:27 -0700 Subject: 1. Refactored SparkStreamContext, Scheduler, InputRDS, FileInputRDS and a few other files. 2. Modified Time class to represent milliseconds (long) directly, instead of LongTime. 3. Added new files QueueInputRDS, RecurringTimer, etc. 4. Added RDDSuite as the skeleton for testcases. 5. Added two examples in spark.streaming.examples. 6. Removed all past examples and a few unnecessary files. Moved a number of files to spark.streaming.util. --- core/src/main/scala/spark/Utils.scala | 2 +- .../src/main/scala/spark/streaming/BlockID.scala | 20 -- .../main/scala/spark/streaming/FileInputRDS.scala | 163 +++++++++++ .../scala/spark/streaming/FileStreamReceiver.scala | 70 ----- .../scala/spark/streaming/IdealPerformance.scala | 36 --- .../src/main/scala/spark/streaming/Interval.scala | 6 +- streaming/src/main/scala/spark/streaming/Job.scala | 9 +- .../main/scala/spark/streaming/JobManager.scala | 23 +- .../spark/streaming/NetworkStreamReceiver.scala | 184 ------------- .../scala/spark/streaming/PairRDSFunctions.scala | 72 +++++ .../main/scala/spark/streaming/QueueInputRDS.scala | 36 +++ streaming/src/main/scala/spark/streaming/RDS.scala | 305 ++++++--------------- .../src/main/scala/spark/streaming/Scheduler.scala | 171 ++---------- .../scala/spark/streaming/SparkStreamContext.scala | 150 +++++++--- .../spark/streaming/TestInputBlockTracker.scala | 42 --- .../spark/streaming/TestStreamReceiver3.scala | 18 +- .../spark/streaming/TestStreamReceiver4.scala | 16 +- .../src/main/scala/spark/streaming/Time.scala | 100 ++++--- .../examples/DumbTopKWordCount2_Special.scala | 138 ---------- .../examples/DumbWordCount2_Special.scala | 92 ------- .../spark/streaming/examples/ExampleOne.scala | 41 +++ .../spark/streaming/examples/ExampleTwo.scala | 47 ++++ .../scala/spark/streaming/examples/GrepCount.scala | 39 --- .../spark/streaming/examples/GrepCount2.scala | 113 -------- .../spark/streaming/examples/GrepCountApprox.scala | 54 ---- .../spark/streaming/examples/SimpleWordCount.scala | 30 -- .../streaming/examples/SimpleWordCount2.scala | 51 ---- .../examples/SimpleWordCount2_Special.scala | 83 ------ .../spark/streaming/examples/TopContentCount.scala | 97 ------- .../spark/streaming/examples/TopKWordCount2.scala | 103 ------- .../examples/TopKWordCount2_Special.scala | 142 ---------- .../scala/spark/streaming/examples/WordCount.scala | 62 ----- .../spark/streaming/examples/WordCount1.scala | 46 ---- .../spark/streaming/examples/WordCount2.scala | 55 ---- .../streaming/examples/WordCount2_Special.scala | 94 ------- .../spark/streaming/examples/WordCount3.scala | 49 ---- .../spark/streaming/examples/WordCountEc2.scala | 41 --- .../examples/WordCountTrivialWindow.scala | 51 ---- .../scala/spark/streaming/examples/WordMax.scala | 64 ----- .../spark/streaming/util/RecurringTimer.scala | 52 ++++ .../spark/streaming/util/SenderReceiverTest.scala | 64 +++++ .../streaming/util/SentenceFileGenerator.scala | 92 +++++++ .../scala/spark/streaming/util/ShuffleTest.scala | 23 ++ .../main/scala/spark/streaming/util/Utils.scala | 9 + .../utils/SenGeneratorForPerformanceTest.scala | 78 ------ .../spark/streaming/utils/SenderReceiverTest.scala | 63 ----- .../streaming/utils/SentenceFileGenerator.scala | 92 ------- .../spark/streaming/utils/SentenceGenerator.scala | 103 ------- .../scala/spark/streaming/utils/ShuffleTest.scala | 22 -- .../src/test/scala/spark/streaming/RDSSuite.scala | 65 +++++ 50 files changed, 971 insertions(+), 2607 deletions(-) delete mode 100644 streaming/src/main/scala/spark/streaming/BlockID.scala create mode 100644 streaming/src/main/scala/spark/streaming/FileInputRDS.scala delete mode 100644 streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala delete mode 100644 streaming/src/main/scala/spark/streaming/IdealPerformance.scala delete mode 100644 streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala create mode 100644 streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala create mode 100644 streaming/src/main/scala/spark/streaming/QueueInputRDS.scala delete mode 100644 streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount1.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount3.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordMax.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/Utils.scala delete mode 100644 streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala delete mode 100644 streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala delete mode 100644 streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala delete mode 100644 streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala delete mode 100644 streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala create mode 100644 streaming/src/test/scala/spark/streaming/RDSSuite.scala (limited to 'core') diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 5eda1011f9..1d33f7d6b3 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -185,7 +185,7 @@ object Utils { * millisecond. */ def getUsedTimeMs(startTimeMs: Long): String = { - return " " + (System.currentTimeMillis - startTimeMs) + " ms " + return " " + (System.currentTimeMillis - startTimeMs) + " ms" } /** diff --git a/streaming/src/main/scala/spark/streaming/BlockID.scala b/streaming/src/main/scala/spark/streaming/BlockID.scala deleted file mode 100644 index 16aacfda18..0000000000 --- a/streaming/src/main/scala/spark/streaming/BlockID.scala +++ /dev/null @@ -1,20 +0,0 @@ -package spark.streaming - -case class BlockID(sRds: String, sInterval: Interval, sPartition: Int) { - override def toString : String = ( - sRds + BlockID.sConnector + - sInterval.beginTime + BlockID.sConnector + - sInterval.endTime + BlockID.sConnector + - sPartition - ) -} - -object BlockID { - val sConnector = '-' - - def parse(name : String) = BlockID( - name.split(BlockID.sConnector)(0), - new Interval(name.split(BlockID.sConnector)(1).toLong, - name.split(BlockID.sConnector)(2).toLong), - name.split(BlockID.sConnector)(3).toInt) -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/FileInputRDS.scala b/streaming/src/main/scala/spark/streaming/FileInputRDS.scala new file mode 100644 index 0000000000..dde80cd27a --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/FileInputRDS.scala @@ -0,0 +1,163 @@ +package spark.streaming + +import spark.SparkContext +import spark.RDD +import spark.BlockRDD +import spark.UnionRDD +import spark.storage.StorageLevel +import spark.streaming._ + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import java.net.InetSocketAddress + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.PathFilter +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} + + +class FileInputRDS[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( + ssc: SparkStreamContext, + directory: Path, + filter: PathFilter = FileInputRDS.defaultPathFilter, + newFilesOnly: Boolean = true) + extends InputRDS[(K, V)](ssc) { + + val fs = directory.getFileSystem(new Configuration()) + var lastModTime: Long = 0 + + override def start() { + if (newFilesOnly) { + lastModTime = System.currentTimeMillis() + } else { + lastModTime = 0 + } + } + + override def stop() { } + + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + val newFilter = new PathFilter() { + var latestModTime = 0L + + def accept(path: Path): Boolean = { + + if (!filter.accept(path)) { + return false + } else { + val modTime = fs.getFileStatus(path).getModificationTime() + if (modTime < lastModTime) { + return false + } + if (modTime > latestModTime) { + latestModTime = modTime + } + return true + } + } + } + + val newFiles = fs.listStatus(directory, newFilter) + lastModTime = newFilter.latestModTime + val newRDD = new UnionRDD(ssc.sc, newFiles.map(file => + ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString)) + ) + Some(newRDD) + } +} + +object FileInputRDS { + val defaultPathFilter = new PathFilter { + def accept(path: Path): Boolean = { + val file = path.getName() + if (file.startsWith(".") || file.endsWith("_tmp")) { + return false + } else { + return true + } + } + } +} + +/* +class NetworkInputRDS[T: ClassManifest]( + val networkInputName: String, + val addresses: Array[InetSocketAddress], + batchDuration: Time, + ssc: SparkStreamContext) +extends InputRDS[T](networkInputName, batchDuration, ssc) { + + + // TODO(Haoyuan): This is for the performance test. + @transient var rdd: RDD[T] = null + + if (System.getProperty("spark.fake", "false") == "true") { + logInfo("Running initial count to cache fake RDD") + rdd = ssc.sc.textFile(SparkContext.inputFile, + SparkContext.idealPartitions).asInstanceOf[RDD[T]] + val fakeCacheLevel = System.getProperty("spark.fake.cache", "") + if (fakeCacheLevel == "MEMORY_ONLY_2") { + rdd.persist(StorageLevel.MEMORY_ONLY_2) + } else if (fakeCacheLevel == "MEMORY_ONLY_DESER_2") { + rdd.persist(StorageLevel.MEMORY_ONLY_2) + } else if (fakeCacheLevel != "") { + logError("Invalid fake cache level: " + fakeCacheLevel) + System.exit(1) + } + rdd.count() + } + + @transient val references = new HashMap[Time,String] + + override def compute(validTime: Time): Option[RDD[T]] = { + if (System.getProperty("spark.fake", "false") == "true") { + logInfo("Returning fake RDD at " + validTime) + return Some(rdd) + } + references.get(validTime) match { + case Some(reference) => + if (reference.startsWith("file") || reference.startsWith("hdfs")) { + logInfo("Reading from file " + reference + " for time " + validTime) + Some(ssc.sc.textFile(reference).asInstanceOf[RDD[T]]) + } else { + logInfo("Getting from BlockManager " + reference + " for time " + validTime) + Some(new BlockRDD(ssc.sc, Array(reference))) + } + case None => + throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") + None + } + } + + def setReference(time: Time, reference: AnyRef) { + references += ((time, reference.toString)) + logInfo("Reference added for time " + time + " - " + reference.toString) + } +} + + +class TestInputRDS( + val testInputName: String, + batchDuration: Time, + ssc: SparkStreamContext) +extends InputRDS[String](testInputName, batchDuration, ssc) { + + @transient val references = new HashMap[Time,Array[String]] + + override def compute(validTime: Time): Option[RDD[String]] = { + references.get(validTime) match { + case Some(reference) => + Some(new BlockRDD[String](ssc.sc, reference)) + case None => + throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") + None + } + } + + def setReference(time: Time, reference: AnyRef) { + references += ((time, reference.asInstanceOf[Array[String]])) + } +} +*/ diff --git a/streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala b/streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala deleted file mode 100644 index 92c7cfe00c..0000000000 --- a/streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala +++ /dev/null @@ -1,70 +0,0 @@ -package spark.streaming - -import spark.Logging - -import scala.collection.mutable.HashSet -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -class FileStreamReceiver ( - inputName: String, - rootDirectory: String, - intervalDuration: Long) - extends Logging { - - val pollInterval = 100 - val sparkstreamScheduler = { - val host = System.getProperty("spark.master.host") - val port = System.getProperty("spark.master.port").toInt + 1 - RemoteActor.select(Node(host, port), 'SparkStreamScheduler) - } - val directory = new Path(rootDirectory) - val fs = directory.getFileSystem(new Configuration()) - val files = new HashSet[String]() - var time: Long = 0 - - def start() { - fs.mkdirs(directory) - files ++= getFiles() - - actor { - logInfo("Monitoring directory - " + rootDirectory) - while(true) { - testFiles(getFiles()) - Thread.sleep(pollInterval) - } - } - } - - def getFiles(): Iterable[String] = { - fs.listStatus(directory).map(_.getPath.toString) - } - - def testFiles(fileList: Iterable[String]) { - fileList.foreach(file => { - if (!files.contains(file)) { - if (!file.endsWith("_tmp")) { - notifyFile(file) - } - files += file - } - }) - } - - def notifyFile(file: String) { - logInfo("Notifying file " + file) - time += intervalDuration - val interval = Interval(LongTime(time), LongTime(time + intervalDuration)) - sparkstreamScheduler ! InputGenerated(inputName, interval, file) - } -} - - diff --git a/streaming/src/main/scala/spark/streaming/IdealPerformance.scala b/streaming/src/main/scala/spark/streaming/IdealPerformance.scala deleted file mode 100644 index 303d4e7ae6..0000000000 --- a/streaming/src/main/scala/spark/streaming/IdealPerformance.scala +++ /dev/null @@ -1,36 +0,0 @@ -package spark.streaming - -import scala.collection.mutable.Map - -object IdealPerformance { - val base: String = "The medium researcher counts around the pinched troop The empire breaks " + - "Matei Matei announces HY with a theorem " - - def main (args: Array[String]) { - val sentences: String = base * 100000 - - for (i <- 1 to 30) { - val start = System.nanoTime - - val words = sentences.split(" ") - - val pairs = words.map(word => (word, 1)) - - val counts = Map[String, Int]() - - println("Job " + i + " position A at " + (System.nanoTime - start) / 1e9) - - pairs.foreach((pair) => { - var t = counts.getOrElse(pair._1, 0) - counts(pair._1) = t + pair._2 - }) - println("Job " + i + " position B at " + (System.nanoTime - start) / 1e9) - - for ((word, count) <- counts) { - print(word + " " + count + "; ") - } - println - println("Job " + i + " finished in " + (System.nanoTime - start) / 1e9) - } - } -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala index 9a61d85274..1960097216 100644 --- a/streaming/src/main/scala/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/spark/streaming/Interval.scala @@ -2,7 +2,7 @@ package spark.streaming case class Interval (val beginTime: Time, val endTime: Time) { - def this(beginMs: Long, endMs: Long) = this(new LongTime(beginMs), new LongTime(endMs)) + def this(beginMs: Long, endMs: Long) = this(Time(beginMs), new Time(endMs)) def duration(): Time = endTime - beginTime @@ -44,8 +44,8 @@ object Interval { def zero() = new Interval (Time.zero, Time.zero) - def currentInterval(intervalDuration: LongTime): Interval = { - val time = LongTime(System.currentTimeMillis) + def currentInterval(intervalDuration: Time): Interval = { + val time = Time(System.currentTimeMillis) val intervalBegin = time.floor(intervalDuration) Interval(intervalBegin, intervalBegin + intervalDuration) } diff --git a/streaming/src/main/scala/spark/streaming/Job.scala b/streaming/src/main/scala/spark/streaming/Job.scala index f7654dff79..36958dafe1 100644 --- a/streaming/src/main/scala/spark/streaming/Job.scala +++ b/streaming/src/main/scala/spark/streaming/Job.scala @@ -1,13 +1,14 @@ package spark.streaming +import spark.streaming.util.Utils + class Job(val time: Time, func: () => _) { val id = Job.getNewId() - - def run() { - func() + def run(): Long = { + Utils.time { func() } } - override def toString = "SparkStream Job " + id + ":" + time + override def toString = "streaming job " + id + " @ " + time } object Job { diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index d7d88a7000..43d167f7db 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -1,6 +1,7 @@ package spark.streaming -import spark.{Logging, SparkEnv} +import spark.Logging +import spark.SparkEnv import java.util.concurrent.Executors @@ -10,19 +11,14 @@ class JobManager(ssc: SparkStreamContext, numThreads: Int = 1) extends Logging { def run() { SparkEnv.set(ssc.env) try { - logInfo("Starting " + job) - job.run() - logInfo("Finished " + job) - if (job.time.isInstanceOf[LongTime]) { - val longTime = job.time.asInstanceOf[LongTime] - logInfo("Total notification + skew + processing delay for " + longTime + " is " + - (System.currentTimeMillis - longTime.milliseconds) / 1000.0 + " s") - if (System.getProperty("spark.stream.distributed", "false") == "true") { - TestInputBlockTracker.setEndTime(job.time) - } - } + val timeTaken = job.run() + logInfo( + "Runnning " + job + " took " + timeTaken + " ms, " + + "total delay was " + (System.currentTimeMillis - job.time) + " ms" + ) } catch { - case e: Exception => logError("SparkStream job failed", e) + case e: Exception => + logError("Running " + job + " failed", e) } } } @@ -33,5 +29,6 @@ class JobManager(ssc: SparkStreamContext, numThreads: Int = 1) extends Logging { def runJob(job: Job) { jobExecutor.execute(new JobHandler(ssc, job)) + logInfo("Added " + job + " to queue") } } diff --git a/streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala b/streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala deleted file mode 100644 index efd4689cf0..0000000000 --- a/streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala +++ /dev/null @@ -1,184 +0,0 @@ -package spark.streaming - -import spark.Logging -import spark.storage.StorageLevel - -import scala.math._ -import scala.collection.mutable.{Queue, HashMap, ArrayBuffer} -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.io.BufferedWriter -import java.io.OutputStreamWriter - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -/*import akka.actor.Actor._*/ - -class NetworkStreamReceiver[T: ClassManifest] ( - inputName: String, - intervalDuration: Time, - splitId: Int, - ssc: SparkStreamContext, - tempDirectory: String) - extends DaemonActor - with Logging { - - /** - * Assume all data coming in has non-decreasing timestamp. - */ - final class Inbox[T: ClassManifest] (intervalDuration: Time) { - var currentBucket: (Interval, ArrayBuffer[T]) = null - val filledBuckets = new Queue[(Interval, ArrayBuffer[T])]() - - def += (tuple: (Time, T)) = addTuple(tuple) - - def addTuple(tuple: (Time, T)) { - val (time, data) = tuple - val interval = getInterval (time) - - filledBuckets.synchronized { - if (currentBucket == null) { - currentBucket = (interval, new ArrayBuffer[T]()) - } - - if (interval != currentBucket._1) { - filledBuckets += currentBucket - currentBucket = (interval, new ArrayBuffer[T]()) - } - - currentBucket._2 += data - } - } - - def getInterval(time: Time): Interval = { - val intervalBegin = time.floor(intervalDuration) - Interval (intervalBegin, intervalBegin + intervalDuration) - } - - def hasFilledBuckets(): Boolean = { - filledBuckets.synchronized { - return filledBuckets.size > 0 - } - } - - def popFilledBucket(): (Interval, ArrayBuffer[T]) = { - filledBuckets.synchronized { - if (filledBuckets.size == 0) { - return null - } - return filledBuckets.dequeue() - } - } - } - - val inbox = new Inbox[T](intervalDuration) - lazy val sparkstreamScheduler = { - val host = System.getProperty("spark.master.host") - val port = System.getProperty("spark.master.port").toInt - val url = "akka://spark@%s:%s/user/SparkStreamScheduler".format(host, port) - ssc.actorSystem.actorFor(url) - } - /*sparkstreamScheduler ! Test()*/ - - val intervalDurationMillis = intervalDuration.asInstanceOf[LongTime].milliseconds - val useBlockManager = true - - initLogging() - - override def act() { - // register the InputReceiver - val port = 7078 - RemoteActor.alive(port) - RemoteActor.register(Symbol("NetworkStreamReceiver-"+inputName), self) - logInfo("Registered actor on port " + port) - - loop { - reactWithin (getSleepTime) { - case TIMEOUT => - flushInbox() - case data => - val t = data.asInstanceOf[T] - inbox += (getTimeFromData(t), t) - } - } - } - - def getSleepTime(): Long = { - (System.currentTimeMillis / intervalDurationMillis + 1) * - intervalDurationMillis - System.currentTimeMillis - } - - def getTimeFromData(data: T): Time = { - LongTime(System.currentTimeMillis) - } - - def flushInbox() { - while (inbox.hasFilledBuckets) { - inbox.synchronized { - val (interval, data) = inbox.popFilledBucket() - val dataArray = data.toArray - logInfo("Received " + dataArray.length + " items at interval " + interval) - val reference = { - if (useBlockManager) { - writeToBlockManager(dataArray, interval) - } else { - writeToDisk(dataArray, interval) - } - } - if (reference != null) { - logInfo("Notifying scheduler") - sparkstreamScheduler ! InputGenerated(inputName, interval, reference.toString) - } - } - } - } - - def writeToDisk(data: Array[T], interval: Interval): String = { - try { - // TODO(Haoyuan): For current test, the following writing to file lines could be - // commented. - val fs = new Path(tempDirectory).getFileSystem(new Configuration()) - val inputDir = new Path( - tempDirectory, - inputName + "-" + interval.toFormattedString) - val inputFile = new Path(inputDir, "part-" + splitId) - logInfo("Writing to file " + inputFile) - if (System.getProperty("spark.fake", "false") != "true") { - val writer = new BufferedWriter(new OutputStreamWriter(fs.create(inputFile, true))) - data.foreach(x => writer.write(x.toString + "\n")) - writer.close() - } else { - logInfo("Fake file") - } - inputFile.toString - }catch { - case e: Exception => - logError("Exception writing to file at interval " + interval + ": " + e.getMessage, e) - null - } - } - - def writeToBlockManager(data: Array[T], interval: Interval): String = { - try{ - val blockId = inputName + "-" + interval.toFormattedString + "-" + splitId - if (System.getProperty("spark.fake", "false") != "true") { - logInfo("Writing as block " + blockId ) - ssc.env.blockManager.put(blockId.toString, data.toIterator, StorageLevel.DISK_AND_MEMORY) - } else { - logInfo("Fake block") - } - blockId - } catch { - case e: Exception => - logError("Exception writing to block manager at interval " + interval + ": " + e.getMessage, e) - null - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala b/streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala new file mode 100644 index 0000000000..403ae233a5 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala @@ -0,0 +1,72 @@ +package spark.streaming + +import scala.collection.mutable.ArrayBuffer +import spark.streaming.SparkStreamContext._ + +class PairRDSFunctions[K: ClassManifest, V: ClassManifest](rds: RDS[(K,V)]) +extends Serializable { + + def ssc = rds.ssc + + /* ---------------------------------- */ + /* RDS operations for key-value pairs */ + /* ---------------------------------- */ + + def groupByKey(numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { + def createCombiner(v: V) = ArrayBuffer[V](v) + def mergeValue(c: ArrayBuffer[V], v: V) = (c += v) + def mergeCombiner(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = (c1 ++ c2) + combineByKey[ArrayBuffer[V]](createCombiner, mergeValue, mergeCombiner, numPartitions) + } + + def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int = 0): ShuffledRDS[K, V, V] = { + val cleanedReduceFunc = ssc.sc.clean(reduceFunc) + combineByKey[V]((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, numPartitions) + } + + private def combineByKey[C: ClassManifest]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + numPartitions: Int) : ShuffledRDS[K, V, C] = { + new ShuffledRDS[K, V, C](rds, createCombiner, mergeValue, mergeCombiner, numPartitions) + } + + def groupByKeyAndWindow( + windowTime: Time, + slideTime: Time, + numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { + rds.window(windowTime, slideTime).groupByKey(numPartitions) + } + + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int = 0): ShuffledRDS[K, V, V] = { + rds.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), numPartitions) + } + + // This method is the efficient sliding window reduce operation, + // which requires the specification of an inverse reduce function, + // so that new elements introduced in the window can be "added" using + // reduceFunc to the previous window's result and old elements can be + // "subtracted using invReduceFunc. + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int): ReducedWindowedRDS[K, V] = { + + new ReducedWindowedRDS[K, V]( + rds, + ssc.sc.clean(reduceFunc), + ssc.sc.clean(invReduceFunc), + windowTime, + slideTime, + numPartitions) + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/QueueInputRDS.scala b/streaming/src/main/scala/spark/streaming/QueueInputRDS.scala new file mode 100644 index 0000000000..31e6a64e21 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/QueueInputRDS.scala @@ -0,0 +1,36 @@ +package spark.streaming + +import spark.RDD +import spark.UnionRDD + +import scala.collection.mutable.Queue +import scala.collection.mutable.ArrayBuffer + +class QueueInputRDS[T: ClassManifest]( + ssc: SparkStreamContext, + val queue: Queue[RDD[T]], + oneAtATime: Boolean, + defaultRDD: RDD[T] + ) extends InputRDS[T](ssc) { + + override def start() { } + + override def stop() { } + + override def compute(validTime: Time): Option[RDD[T]] = { + val buffer = new ArrayBuffer[RDD[T]]() + if (oneAtATime && queue.size > 0) { + buffer += queue.dequeue() + } else { + buffer ++= queue + } + if (buffer.size > 0) { + Some(new UnionRDD(ssc.sc, buffer.toSeq)) + } else if (defaultRDD != null) { + Some(defaultRDD) + } else { + None + } + } + +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/RDS.scala b/streaming/src/main/scala/spark/streaming/RDS.scala index c8dd1015ed..fd923929e7 100644 --- a/streaming/src/main/scala/spark/streaming/RDS.scala +++ b/streaming/src/main/scala/spark/streaming/RDS.scala @@ -13,16 +13,18 @@ import spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import java.net.InetSocketAddress +import java.util.concurrent.ArrayBlockingQueue abstract class RDS[T: ClassManifest] (@transient val ssc: SparkStreamContext) extends Logging with Serializable { initLogging() - /* ---------------------------------------------- */ - /* Methods that must be implemented by subclasses */ - /* ---------------------------------------------- */ + /** + * ---------------------------------------------- + * Methods that must be implemented by subclasses + * ---------------------------------------------- + */ // Time by which the window slides in this RDS def slideTime: Time @@ -33,9 +35,11 @@ extends Logging with Serializable { // Key method that computes RDD for a valid time def compute (validTime: Time): Option[RDD[T]] - /* --------------------------------------- */ - /* Other general fields and methods of RDS */ - /* --------------------------------------- */ + /** + * --------------------------------------- + * Other general fields and methods of RDS + * --------------------------------------- + */ // Variable to store the RDDs generated earlier in time @transient private val generatedRDDs = new HashMap[Time, RDD[T]] () @@ -66,9 +70,9 @@ extends Logging with Serializable { this } + // Set caching level for the RDDs created by this RDS def persist(newLevel: StorageLevel): RDS[T] = persist(newLevel, StorageLevel.NONE, null) - // Turn on the default caching level for this RDD def persist(): RDS[T] = persist(StorageLevel.MEMORY_ONLY_DESER) // Turn on the default caching level for this RDD @@ -76,18 +80,20 @@ extends Logging with Serializable { def isInitialized = (zeroTime != null) - // This method initializes the RDS by setting the "zero" time, based on which - // the validity of future times is calculated. This method also recursively initializes - // its parent RDSs. - def initialize(firstInterval: Interval) { + /** + * This method initializes the RDS by setting the "zero" time, based on which + * the validity of future times is calculated. This method also recursively initializes + * its parent RDSs. + */ + def initialize(time: Time) { if (zeroTime == null) { - zeroTime = firstInterval.beginTime + zeroTime = time } logInfo(this + " initialized") - dependencies.foreach(_.initialize(firstInterval)) + dependencies.foreach(_.initialize(zeroTime)) } - // This method checks whether the 'time' is valid wrt slideTime for generating RDD + /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ private def isTimeValid (time: Time): Boolean = { if (!isInitialized) throw new Exception (this.toString + " has not been initialized") @@ -98,11 +104,13 @@ extends Logging with Serializable { } } - // This method either retrieves a precomputed RDD of this RDS, - // or computes the RDD (if the time is valid) + /** + * This method either retrieves a precomputed RDD of this RDS, + * or computes the RDD (if the time is valid) + */ def getOrCompute(time: Time): Option[RDD[T]] = { - - // if RDD was already generated, then retrieve it from HashMap + // If this RDS was not initialized (i.e., zeroTime not set), then do it + // If RDD was already generated, then retrieve it from HashMap generatedRDDs.get(time) match { // If an RDD was already generated and is being reused, then @@ -115,15 +123,12 @@ extends Logging with Serializable { if (isTimeValid(time)) { compute(time) match { case Some(newRDD) => - if (System.getProperty("spark.fake", "false") != "true" || - newRDD.getStorageLevel == StorageLevel.NONE) { - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { - newRDD.persist(checkpointLevel) - logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) - } else if (storageLevel != StorageLevel.NONE) { - newRDD.persist(storageLevel) - logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) - } + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.persist(checkpointLevel) + logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) + } else if (storageLevel != StorageLevel.NONE) { + newRDD.persist(storageLevel) + logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) } generatedRDDs.put(time.copy(), newRDD) Some(newRDD) @@ -136,8 +141,10 @@ extends Logging with Serializable { } } - // This method generates a SparkStream job for the given time - // and may require to be overriden by subclasses + /** + * This method generates a SparkStream job for the given time + * and may require to be overriden by subclasses + */ def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { case Some(rdd) => { @@ -151,9 +158,11 @@ extends Logging with Serializable { } } - /* -------------- */ - /* RDS operations */ - /* -------------- */ + /** + * -------------- + * RDS operations + * -------------- + */ def map[U: ClassManifest](mapFunc: T => U) = new MappedRDS(this, ssc.sc.clean(mapFunc)) @@ -185,6 +194,15 @@ extends Logging with Serializable { newrds } + private[streaming] def toQueue() = { + val queue = new ArrayBlockingQueue[RDD[T]](10000) + this.foreachRDD(rdd => { + println("Added RDD " + rdd.id) + queue.add(rdd) + }) + queue + } + def print() = { def foreachFunc = (rdd: RDD[T], time: Time) => { val first11 = rdd.take(11) @@ -229,198 +247,23 @@ extends Logging with Serializable { } -class PairRDSFunctions[K: ClassManifest, V: ClassManifest](rds: RDS[(K,V)]) -extends Serializable { - - def ssc = rds.ssc - - /* ---------------------------------- */ - /* RDS operations for key-value pairs */ - /* ---------------------------------- */ - - def groupByKey(numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { - def createCombiner(v: V) = ArrayBuffer[V](v) - def mergeValue(c: ArrayBuffer[V], v: V) = (c += v) - def mergeCombiner(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = (c1 ++ c2) - combineByKey[ArrayBuffer[V]](createCombiner, mergeValue, mergeCombiner, numPartitions) - } - - def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int = 0): ShuffledRDS[K, V, V] = { - val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - combineByKey[V]((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, numPartitions) - } - - private def combineByKey[C: ClassManifest]( - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiner: (C, C) => C, - numPartitions: Int) : ShuffledRDS[K, V, C] = { - new ShuffledRDS[K, V, C](rds, createCombiner, mergeValue, mergeCombiner, numPartitions) - } - - def groupByKeyAndWindow( - windowTime: Time, - slideTime: Time, - numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { - rds.window(windowTime, slideTime).groupByKey(numPartitions) - } - - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, - numPartitions: Int = 0): ShuffledRDS[K, V, V] = { - rds.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), numPartitions) - } - - // This method is the efficient sliding window reduce operation, - // which requires the specification of an inverse reduce function, - // so that new elements introduced in the window can be "added" using - // reduceFunc to the previous window's result and old elements can be - // "subtracted using invReduceFunc. - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, - numPartitions: Int): ReducedWindowedRDS[K, V] = { - - new ReducedWindowedRDS[K, V]( - rds, - ssc.sc.clean(reduceFunc), - ssc.sc.clean(invReduceFunc), - windowTime, - slideTime, - numPartitions) - } -} - - abstract class InputRDS[T: ClassManifest] ( - val inputName: String, - val batchDuration: Time, ssc: SparkStreamContext) extends RDS[T](ssc) { override def dependencies = List() - override def slideTime = batchDuration + override def slideTime = ssc.batchDuration - def setReference(time: Time, reference: AnyRef) -} - - -class FileInputRDS( - val fileInputName: String, - val directory: String, - ssc: SparkStreamContext) -extends InputRDS[String](fileInputName, LongTime(1000), ssc) { - - @transient val generatedFiles = new HashMap[Time,String] - - // TODO(Haoyuan): This is for the performance test. - @transient - val rdd = ssc.sc.textFile(SparkContext.inputFile, - SparkContext.idealPartitions).asInstanceOf[RDD[String]] + def start() - override def compute(validTime: Time): Option[RDD[String]] = { - generatedFiles.get(validTime) match { - case Some(file) => - logInfo("Reading from file " + file + " for time " + validTime) - // Some(ssc.sc.textFile(file).asInstanceOf[RDD[String]]) - // The following line is for HDFS performance test. Sould comment out the above line. - Some(rdd) - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - generatedFiles += ((time, reference.toString)) - logInfo("Reference added for time " + time + " - " + reference.toString) - } -} - -class NetworkInputRDS[T: ClassManifest]( - val networkInputName: String, - val addresses: Array[InetSocketAddress], - batchDuration: Time, - ssc: SparkStreamContext) -extends InputRDS[T](networkInputName, batchDuration, ssc) { - - - // TODO(Haoyuan): This is for the performance test. - @transient var rdd: RDD[T] = null - - if (System.getProperty("spark.fake", "false") == "true") { - logInfo("Running initial count to cache fake RDD") - rdd = ssc.sc.textFile(SparkContext.inputFile, - SparkContext.idealPartitions).asInstanceOf[RDD[T]] - val fakeCacheLevel = System.getProperty("spark.fake.cache", "") - if (fakeCacheLevel == "MEMORY_ONLY_2") { - rdd.persist(StorageLevel.MEMORY_ONLY_2) - } else if (fakeCacheLevel == "MEMORY_ONLY_DESER_2") { - rdd.persist(StorageLevel.MEMORY_ONLY_2) - } else if (fakeCacheLevel != "") { - logError("Invalid fake cache level: " + fakeCacheLevel) - System.exit(1) - } - rdd.count() - } - - @transient val references = new HashMap[Time,String] - - override def compute(validTime: Time): Option[RDD[T]] = { - if (System.getProperty("spark.fake", "false") == "true") { - logInfo("Returning fake RDD at " + validTime) - return Some(rdd) - } - references.get(validTime) match { - case Some(reference) => - if (reference.startsWith("file") || reference.startsWith("hdfs")) { - logInfo("Reading from file " + reference + " for time " + validTime) - Some(ssc.sc.textFile(reference).asInstanceOf[RDD[T]]) - } else { - logInfo("Getting from BlockManager " + reference + " for time " + validTime) - Some(new BlockRDD(ssc.sc, Array(reference))) - } - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - references += ((time, reference.toString)) - logInfo("Reference added for time " + time + " - " + reference.toString) - } + def stop() } -class TestInputRDS( - val testInputName: String, - batchDuration: Time, - ssc: SparkStreamContext) -extends InputRDS[String](testInputName, batchDuration, ssc) { - - @transient val references = new HashMap[Time,Array[String]] - - override def compute(validTime: Time): Option[RDD[String]] = { - references.get(validTime) match { - case Some(reference) => - Some(new BlockRDD[String](ssc.sc, reference)) - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - references += ((time, reference.asInstanceOf[Array[String]])) - } -} - +/** + * TODO + */ class MappedRDS[T: ClassManifest, U: ClassManifest] ( parent: RDS[T], @@ -437,6 +280,10 @@ extends RDS[U](parent.ssc) { } +/** + * TODO + */ + class FlatMappedRDS[T: ClassManifest, U: ClassManifest]( parent: RDS[T], flatMapFunc: T => Traversable[U]) @@ -452,6 +299,10 @@ extends RDS[U](parent.ssc) { } +/** + * TODO + */ + class FilteredRDS[T: ClassManifest](parent: RDS[T], filterFunc: T => Boolean) extends RDS[T](parent.ssc) { @@ -464,6 +315,11 @@ extends RDS[T](parent.ssc) { } } + +/** + * TODO + */ + class MapPartitionedRDS[T: ClassManifest, U: ClassManifest]( parent: RDS[T], mapPartFunc: Iterator[T] => Iterator[U]) @@ -478,6 +334,11 @@ extends RDS[U](parent.ssc) { } } + +/** + * TODO + */ + class GlommedRDS[T: ClassManifest](parent: RDS[T]) extends RDS[Array[T]](parent.ssc) { override def dependencies = List(parent) @@ -490,6 +351,10 @@ class GlommedRDS[T: ClassManifest](parent: RDS[T]) extends RDS[Array[T]](parent. } +/** + * TODO + */ + class ShuffledRDS[K: ClassManifest, V: ClassManifest, C: ClassManifest]( parent: RDS[(K,V)], createCombiner: V => C, @@ -519,6 +384,10 @@ class ShuffledRDS[K: ClassManifest, V: ClassManifest, C: ClassManifest]( } +/** + * TODO + */ + class UnifiedRDS[T: ClassManifest](parents: Array[RDS[T]]) extends RDS[T](parents(0).ssc) { @@ -553,6 +422,10 @@ extends RDS[T](parents(0).ssc) { } +/** + * TODO + */ + class PerElementForEachRDS[T: ClassManifest] ( parent: RDS[T], foreachFunc: T => Unit) @@ -580,6 +453,10 @@ extends RDS[Unit](parent.ssc) { } +/** + * TODO + */ + class PerRDDForEachRDS[T: ClassManifest] ( parent: RDS[T], foreachFunc: (RDD[T], Time) => Unit) diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 8df346559c..83f874e550 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -1,16 +1,11 @@ package spark.streaming +import spark.streaming.util.RecurringTimer import spark.SparkEnv import spark.Logging import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.collection.mutable.ArrayBuffer -import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ -import akka.util.duration._ sealed trait SchedulerMessage case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage @@ -20,162 +15,42 @@ class Scheduler( ssc: SparkStreamContext, inputRDSs: Array[InputRDS[_]], outputRDSs: Array[RDS[_]]) -extends Actor with Logging { - - class InputState (inputNames: Array[String]) { - val inputsLeft = new HashSet[String]() - inputsLeft ++= inputNames - - val startTime = System.currentTimeMillis - - def delay() = System.currentTimeMillis - startTime - - def addGeneratedInput(inputName: String) = inputsLeft -= inputName - - def areAllInputsGenerated() = (inputsLeft.size == 0) - - override def toString(): String = { - val left = if (inputsLeft.size == 0) "" else inputsLeft.reduceLeft(_ + ", " + _) - return "Inputs left = [ " + left + " ]" - } - } - +extends Logging { initLogging() - val inputNames = inputRDSs.map(_.inputName).toArray - val inputStates = new HashMap[Interval, InputState]() - val currentJobs = System.getProperty("spark.stream.currentJobs", "1").toInt - val jobManager = new JobManager(ssc, currentJobs) - - // TODO(Haoyuan): The following line is for performance test only. - var cnt: Int = System.getProperty("spark.stream.fake.cnt", "60").toInt - var lastInterval: Interval = null - - - /*remote.register("SparkStreamScheduler", actorOf[Scheduler])*/ - logInfo("Registered actor on port ") - - /*jobManager.start()*/ - startStreamReceivers() - - def receive = { - case InputGenerated(inputName, interval, reference) => { - addGeneratedInput(inputName, interval, reference) - } - case Test() => logInfo("TEST PASSED") - } - - def addGeneratedInput(inputName: String, interval: Interval, reference: AnyRef = null) { - logInfo("Input " + inputName + " generated for interval " + interval) - inputStates.get(interval) match { - case None => inputStates.put(interval, new InputState(inputNames)) - case _ => - } - inputStates(interval).addGeneratedInput(inputName) + val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt + val jobManager = new JobManager(ssc, concurrentJobs) + val timer = new RecurringTimer(ssc.batchDuration, generateRDDs(_)) - inputRDSs.filter(_.inputName == inputName).foreach(inputRDS => { - inputRDS.setReference(interval.endTime, reference) - if (inputRDS.isInstanceOf[TestInputRDS]) { - TestInputBlockTracker.addBlocks(interval.endTime, reference) - } - } - ) - - def getNextInterval(): Option[Interval] = { - logDebug("Last interval is " + lastInterval) - val readyIntervals = inputStates.filter(_._2.areAllInputsGenerated).keys - /*inputState.foreach(println) */ - logDebug("InputState has " + inputStates.size + " intervals, " + readyIntervals.size + " ready intervals") - return readyIntervals.find(lastInterval == null || _.beginTime == lastInterval.endTime) - } - - var nextInterval = getNextInterval() - var count = 0 - while(nextInterval.isDefined) { - val inputState = inputStates.get(nextInterval.get).get - generateRDDsForInterval(nextInterval.get) - logInfo("Skew delay for " + nextInterval.get.endTime + " is " + (inputState.delay / 1000.0) + " s") - inputStates.remove(nextInterval.get) - lastInterval = nextInterval.get - nextInterval = getNextInterval() - count += 1 - /*if (nextInterval.size == 0 && inputState.size > 0) { - logDebug("Next interval not ready, pending intervals " + inputState.size) - }*/ - } - logDebug("RDDs generated for " + count + " intervals") - - /* - if (inputState(interval).areAllInputsGenerated) { - generateRDDsForInterval(interval) - lastInterval = interval - inputState.remove(interval) - } else { - logInfo("All inputs not generated for interval " + interval) - } - */ + def start() { + + val zeroTime = Time(timer.start()) + outputRDSs.foreach(_.initialize(zeroTime)) + inputRDSs.par.foreach(_.start()) + logInfo("Scheduler started") } - - def generateRDDsForInterval (interval: Interval) { - logInfo("Generating RDDs for interval " + interval) + + def stop() { + timer.stop() + inputRDSs.par.foreach(_.stop()) + logInfo("Scheduler stopped") + } + + def generateRDDs (time: Time) { + logInfo("Generating RDDs for time " + time) outputRDSs.foreach(outputRDS => { - if (!outputRDS.isInitialized) outputRDS.initialize(interval) - outputRDS.generateJob(interval.endTime) match { + outputRDS.generateJob(time) match { case Some(job) => submitJob(job) case None => } } ) - // TODO(Haoyuan): This comment is for performance test only. - if (System.getProperty("spark.fake", "false") == "true" || System.getProperty("spark.stream.fake", "false") == "true") { - cnt -= 1 - if (cnt <= 0) { - logInfo("My time is up! " + cnt) - System.exit(1) - } - } + logInfo("Generated RDDs for time " + time) } - def submitJob(job: Job) { - logInfo("Submitting " + job + " to JobManager") - /*jobManager ! RunJob(job)*/ + def submitJob(job: Job) { jobManager.runJob(job) } - - def startStreamReceivers() { - val testStreamReceiverNames = new ArrayBuffer[(String, Long)]() - inputRDSs.foreach (inputRDS => { - inputRDS match { - case fileInputRDS: FileInputRDS => { - val fileStreamReceiver = new FileStreamReceiver( - fileInputRDS.inputName, - fileInputRDS.directory, - fileInputRDS.batchDuration.asInstanceOf[LongTime].milliseconds) - fileStreamReceiver.start() - } - case networkInputRDS: NetworkInputRDS[_] => { - val networkStreamReceiver = new NetworkStreamReceiver( - networkInputRDS.inputName, - networkInputRDS.batchDuration, - 0, - ssc, - if (ssc.tempDir == null) null else ssc.tempDir.toString) - networkStreamReceiver.start() - } - case testInputRDS: TestInputRDS => { - testStreamReceiverNames += - ((testInputRDS.inputName, testInputRDS.batchDuration.asInstanceOf[LongTime].milliseconds)) - } - } - }) - if (testStreamReceiverNames.size > 0) { - /*val testStreamCoordinator = new TestStreamCoordinator(testStreamReceiverNames.toArray)*/ - /*testStreamCoordinator.start()*/ - val actor = ssc.actorSystem.actorOf( - Props(new TestStreamCoordinator(testStreamReceiverNames.toArray)), - name = "TestStreamCoordinator") - } - } } diff --git a/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala b/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala index 51f8193740..d32f6d588c 100644 --- a/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala +++ b/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala @@ -1,22 +1,23 @@ package spark.streaming -import spark.SparkContext -import spark.SparkEnv -import spark.Utils +import spark.RDD import spark.Logging +import spark.SparkEnv +import spark.SparkContext import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.Queue -import java.net.InetSocketAddress import java.io.IOException -import java.util.UUID +import java.net.InetSocketAddress +import java.util.concurrent.atomic.AtomicInteger import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration - -import akka.actor._ -import akka.actor.Actor -import akka.util.duration._ +import org.apache.hadoop.io.LongWritable +import org.apache.hadoop.io.Text +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat class SparkStreamContext ( master: String, @@ -24,30 +25,37 @@ class SparkStreamContext ( val sparkHome: String = null, val jars: Seq[String] = Nil) extends Logging { - + initLogging() val sc = new SparkContext(master, frameworkName, sparkHome, jars) val env = SparkEnv.get - val actorSystem = env.actorSystem - - @transient val inputRDSs = new ArrayBuffer[InputRDS[_]]() - @transient val outputRDSs = new ArrayBuffer[RDS[_]]() - var tempDirRoot: String = null - var tempDir: Path = null - - def readNetworkStream[T: ClassManifest]( + val inputRDSs = new ArrayBuffer[InputRDS[_]]() + val outputRDSs = new ArrayBuffer[RDS[_]]() + var batchDuration: Time = null + var scheduler: Scheduler = null + + def setBatchDuration(duration: Long) { + setBatchDuration(Time(duration)) + } + + def setBatchDuration(duration: Time) { + batchDuration = duration + } + + /* + def createNetworkStream[T: ClassManifest]( name: String, addresses: Array[InetSocketAddress], batchDuration: Time): RDS[T] = { - val inputRDS = new NetworkInputRDS[T](name, addresses, batchDuration, this) + val inputRDS = new NetworkInputRDS[T](this, addresses) inputRDSs += inputRDS inputRDS - } + } - def readNetworkStream[T: ClassManifest]( + def createNetworkStream[T: ClassManifest]( name: String, addresses: Array[String], batchDuration: Long): RDS[T] = { @@ -65,40 +73,100 @@ class SparkStreamContext ( addresses.map(stringToInetSocketAddress).toArray, LongTime(batchDuration)) } - - def readFileStream(name: String, directory: String): RDS[String] = { - val path = new Path(directory) - val fs = path.getFileSystem(new Configuration()) - val qualPath = path.makeQualified(fs) - val inputRDS = new FileInputRDS(name, qualPath.toString, this) + */ + + /** + * This function creates a input stream that monitors a Hadoop-compatible + * for new files and executes the necessary processing on them. + */ + def createFileStream[ + K: ClassManifest, + V: ClassManifest, + F <: NewInputFormat[K, V]: ClassManifest + ](directory: String): RDS[(K, V)] = { + val inputRDS = new FileInputRDS[K, V, F](this, new Path(directory)) inputRDSs += inputRDS inputRDS } - def readTestStream(name: String, batchDuration: Long): RDS[String] = { - val inputRDS = new TestInputRDS(name, LongTime(batchDuration), this) - inputRDSs += inputRDS + def createTextFileStream(directory: String): RDS[String] = { + createFileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) + } + + /** + * This function create a input stream from an queue of RDDs. In each batch, + * it will process either one or all of the RDDs returned by the queue + */ + def createQueueStream[T: ClassManifest]( + queue: Queue[RDD[T]], + oneAtATime: Boolean = true, + defaultRDD: RDD[T] = null + ): RDS[T] = { + val inputRDS = new QueueInputRDS(this, queue, oneAtATime, defaultRDD) + inputRDSs += inputRDS inputRDS } + + def createQueueStream[T: ClassManifest](iterator: Iterator[RDD[T]]): RDS[T] = { + val queue = new Queue[RDD[T]] + val inputRDS = createQueueStream(queue, true, null) + queue ++= iterator + inputRDS + } + + /** + * This function registers a RDS as an output stream that will be + * computed every interval. + */ def registerOutputStream (outputRDS: RDS[_]) { outputRDSs += outputRDS } - - def setTempDir(dir: String) { - tempDirRoot = dir + + /** + * This function verify whether the stream computation is eligible to be executed. + */ + def verify() { + if (batchDuration == null) { + throw new Exception("Batch duration has not been set") + } + if (batchDuration < Milliseconds(100)) { + logWarning("Batch duration of " + batchDuration + " is very low") + } + if (inputRDSs.size == 0) { + throw new Exception("No input RDSes created, so nothing to take input from") + } + if (outputRDSs.size == 0) { + throw new Exception("No output RDSes registered, so nothing to execute") + } + } - - def run () { - val ctxt = this - val actor = actorSystem.actorOf( - Props(new Scheduler(ctxt, inputRDSs.toArray, outputRDSs.toArray)), - name = "SparkStreamScheduler") - logInfo("Registered actor") - actorSystem.awaitTermination() + + /** + * This function starts the execution of the streams. + */ + def start() { + verify() + scheduler = new Scheduler(this, inputRDSs.toArray, outputRDSs.toArray) + scheduler.start() + } + + /** + * This function starts the execution of the streams. + */ + def stop() { + try { + scheduler.stop() + sc.stop() + } catch { + case e: Exception => logWarning("Error while stopping", e) + } + + logInfo("SparkStreamContext stopped") } } + object SparkStreamContext { implicit def rdsToPairRdsFunctions [K: ClassManifest, V: ClassManifest] (rds: RDS[(K,V)]) = new PairRDSFunctions (rds) diff --git a/streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala b/streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala deleted file mode 100644 index 7e23b7bb82..0000000000 --- a/streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala +++ /dev/null @@ -1,42 +0,0 @@ -package spark.streaming -import spark.Logging -import scala.collection.mutable.{ArrayBuffer, HashMap} - -object TestInputBlockTracker extends Logging { - initLogging() - val allBlockIds = new HashMap[Time, ArrayBuffer[String]]() - - def addBlocks(intervalEndTime: Time, reference: AnyRef) { - allBlockIds.getOrElseUpdate(intervalEndTime, new ArrayBuffer[String]()) ++= reference.asInstanceOf[Array[String]] - } - - def setEndTime(intervalEndTime: Time) { - try { - val endTime = System.currentTimeMillis - allBlockIds.get(intervalEndTime) match { - case Some(blockIds) => { - val numBlocks = blockIds.size - var totalDelay = 0d - blockIds.foreach(blockId => { - val inputTime = getInputTime(blockId) - val delay = (endTime - inputTime) / 1000.0 - totalDelay += delay - logInfo("End-to-end delay for block " + blockId + " is " + delay + " s") - }) - logInfo("Average end-to-end delay for time " + intervalEndTime + " is " + (totalDelay / numBlocks) + " s") - allBlockIds -= intervalEndTime - } - case None => throw new Exception("Unexpected") - } - } catch { - case e: Exception => logError(e.toString) - } - } - - def getInputTime(blockId: String): Long = { - val parts = blockId.split("-") - /*logInfo(blockId + " -> " + parts(4)) */ - parts(4).toLong - } -} - diff --git a/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala b/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala index a7a5635aa5..bbf2c7bf5e 100644 --- a/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala +++ b/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala @@ -34,8 +34,8 @@ extends Thread with Logging { class DataHandler( inputName: String, - longIntervalDuration: LongTime, - shortIntervalDuration: LongTime, + longIntervalDuration: Time, + shortIntervalDuration: Time, blockManager: BlockManager ) extends Logging { @@ -61,8 +61,8 @@ extends Thread with Logging { initLogging() - val shortIntervalDurationMillis = shortIntervalDuration.asInstanceOf[LongTime].milliseconds - val longIntervalDurationMillis = longIntervalDuration.asInstanceOf[LongTime].milliseconds + val shortIntervalDurationMillis = shortIntervalDuration.toLong + val longIntervalDurationMillis = longIntervalDuration.toLong var currentBlock: Block = null var currentBucket: Bucket = null @@ -101,7 +101,7 @@ extends Thread with Logging { def updateCurrentBlock() { /*logInfo("Updating current block")*/ - val currentTime: LongTime = LongTime(System.currentTimeMillis) + val currentTime = Time(System.currentTimeMillis) val shortInterval = getShortInterval(currentTime) val longInterval = getLongInterval(shortInterval) @@ -318,12 +318,12 @@ extends Thread with Logging { val inputName = streamDetails.name val intervalDurationMillis = streamDetails.duration - val intervalDuration = LongTime(intervalDurationMillis) + val intervalDuration = Time(intervalDurationMillis) val dataHandler = new DataHandler( inputName, intervalDuration, - LongTime(TestStreamReceiver3.SHORT_INTERVAL_MILLIS), + Time(TestStreamReceiver3.SHORT_INTERVAL_MILLIS), blockManager) val connListener = new ConnectionListener(TestStreamReceiver3.PORT, dataHandler) @@ -382,7 +382,7 @@ extends Thread with Logging { def waitFor(time: Time) { val currentTimeMillis = System.currentTimeMillis - val targetTimeMillis = time.asInstanceOf[LongTime].milliseconds + val targetTimeMillis = time.milliseconds if (currentTimeMillis < targetTimeMillis) { val sleepTime = (targetTimeMillis - currentTimeMillis) Thread.sleep(sleepTime + 1) @@ -392,7 +392,7 @@ extends Thread with Logging { def notifyScheduler(interval: Interval, blockIds: Array[String]) { try { sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) - val time = interval.endTime.asInstanceOf[LongTime] + val time = interval.endTime val delay = (System.currentTimeMillis - time.milliseconds) / 1000.0 logInfo("Pushing delay for " + time + " is " + delay + " s") } catch { diff --git a/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala b/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala index 2c3f5d1b9d..a2babb23f4 100644 --- a/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala +++ b/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala @@ -24,8 +24,8 @@ extends Thread with Logging { class DataHandler( inputName: String, - longIntervalDuration: LongTime, - shortIntervalDuration: LongTime, + longIntervalDuration: Time, + shortIntervalDuration: Time, blockManager: BlockManager ) extends Logging { @@ -50,8 +50,8 @@ extends Thread with Logging { val syncOnLastShortInterval = true - val shortIntervalDurationMillis = shortIntervalDuration.asInstanceOf[LongTime].milliseconds - val longIntervalDurationMillis = longIntervalDuration.asInstanceOf[LongTime].milliseconds + val shortIntervalDurationMillis = shortIntervalDuration.milliseconds + val longIntervalDurationMillis = longIntervalDuration.milliseconds val buffer = ByteBuffer.allocateDirect(100 * 1024 * 1024) var currentShortInterval = Interval.currentInterval(shortIntervalDuration) @@ -145,7 +145,7 @@ extends Thread with Logging { if (syncOnLastShortInterval) { bucket += newBlock } - logDebug("Created " + newBlock + " with " + newBuffer.remaining + " bytes, creation delay is " + (System.currentTimeMillis - currentShortInterval.endTime.asInstanceOf[LongTime].milliseconds) / 1000.0 + " s" ) + logDebug("Created " + newBlock + " with " + newBuffer.remaining + " bytes, creation delay is " + (System.currentTimeMillis - currentShortInterval.endTime.milliseconds) / 1000.0 + " s" ) blockPushingExecutor.execute(new Runnable() { def run() { pushAndNotifyBlock(newBlock) } }) } } @@ -175,7 +175,7 @@ extends Thread with Logging { try{ if (blockManager != null) { val startTime = System.currentTimeMillis - logInfo(block + " put start delay is " + (startTime - block.shortInterval.endTime.asInstanceOf[LongTime].milliseconds) + " ms") + logInfo(block + " put start delay is " + (startTime - block.shortInterval.endTime.milliseconds) + " ms") /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY)*/ /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_2)*/ blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY_2) @@ -343,7 +343,7 @@ extends Thread with Logging { def waitFor(time: Time) { val currentTimeMillis = System.currentTimeMillis - val targetTimeMillis = time.asInstanceOf[LongTime].milliseconds + val targetTimeMillis = time.milliseconds if (currentTimeMillis < targetTimeMillis) { val sleepTime = (targetTimeMillis - currentTimeMillis) Thread.sleep(sleepTime + 1) @@ -353,7 +353,7 @@ extends Thread with Logging { def notifyScheduler(interval: Interval, blockIds: Array[String]) { try { sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) - val time = interval.endTime.asInstanceOf[LongTime] + val time = interval.endTime val delay = (System.currentTimeMillis - time.milliseconds) logInfo("Notification delay for " + time + " is " + delay + " ms") } catch { diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index b932fe9258..c4573137ae 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -1,19 +1,34 @@ package spark.streaming -abstract case class Time { +class Time(private var millis: Long) { - // basic operations that must be overridden - def copy(): Time - def zero: Time - def < (that: Time): Boolean - def += (that: Time): Time - def -= (that: Time): Time - def floor(that: Time): Time - def isMultipleOf(that: Time): Boolean + def copy() = new Time(this.millis) + + def zero = Time.zero + + def < (that: Time): Boolean = + (this.millis < that.millis) + + def <= (that: Time) = (this < that || this == that) + + def > (that: Time) = !(this <= that) + + def >= (that: Time) = !(this < that) + + def += (that: Time): Time = { + this.millis += that.millis + this + } + + def -= (that: Time): Time = { + this.millis -= that.millis + this + } - // derived operations composed of basic operations def + (that: Time) = this.copy() += that + def - (that: Time) = this.copy() -= that + def * (times: Int) = { var count = 0 var result = this.copy() @@ -23,63 +38,44 @@ abstract case class Time { } result } - def <= (that: Time) = (this < that || this == that) - def > (that: Time) = !(this <= that) - def >= (that: Time) = !(this < that) - def isZero = (this == zero) - def toFormattedString = toString -} - -object Time { - def Milliseconds(milliseconds: Long) = LongTime(milliseconds) - - def zero = LongTime(0) -} - -case class LongTime(var milliseconds: Long) extends Time { - - override def copy() = LongTime(this.milliseconds) - - override def zero = LongTime(0) - - override def < (that: Time): Boolean = - (this.milliseconds < that.asInstanceOf[LongTime].milliseconds) - - override def += (that: Time): Time = { - this.milliseconds += that.asInstanceOf[LongTime].milliseconds - this - } - override def -= (that: Time): Time = { - this.milliseconds -= that.asInstanceOf[LongTime].milliseconds - this + def floor(that: Time): Time = { + val t = that.millis + val m = math.floor(this.millis / t).toLong + new Time(m * t) } - override def floor(that: Time): Time = { - val t = that.asInstanceOf[LongTime].milliseconds - val m = this.milliseconds / t - LongTime(m.toLong * t) - } + def isMultipleOf(that: Time): Boolean = + (this.millis % that.millis == 0) - override def isMultipleOf(that: Time): Boolean = - (this.milliseconds % that.asInstanceOf[LongTime].milliseconds == 0) + def isZero = (this.millis == 0) - override def isZero = (this.milliseconds == 0) + override def toString() = (millis.toString + " ms") - override def toString = (milliseconds.toString + "ms") + def toFormattedString() = millis.toString + + def milliseconds() = millis +} - override def toFormattedString = milliseconds.toString +object Time { + val zero = new Time(0) + + def apply(milliseconds: Long) = new Time(milliseconds) + + implicit def toTime(long: Long) = Time(long) + + implicit def toLong(time: Time) = time.milliseconds } object Milliseconds { - def apply(milliseconds: Long) = LongTime(milliseconds) + def apply(milliseconds: Long) = Time(milliseconds) } object Seconds { - def apply(seconds: Long) = LongTime(seconds * 1000) + def apply(seconds: Long) = Time(seconds * 1000) } object Minutes { - def apply(minutes: Long) = LongTime(minutes * 60000) + def apply(minutes: Long) = Time(minutes * 60000) } diff --git a/streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala deleted file mode 100644 index 2ca72da79f..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala +++ /dev/null @@ -1,138 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.JavaConversions.mapAsScalaMap -import scala.collection.mutable.Queue - -import java.lang.{Long => JLong} - -object DumbTopKWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - - def add(v1: JLong, v2: JLong) = (v1 + v2) - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } - - - val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - wordCounts.persist(StorageLevel.MEMORY_ONLY) - val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) - - def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { - val taken = new Array[(String, JLong)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, JLong) = null - var swap: (String, JLong) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - /*println("count = " + count)*/ - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out of " + count + " items") - return taken.toIterator - } - - val k = 10 - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) - }) - - /* - windowedCounts.filter(_ == null).foreachRDD(rdd => { - val count = rdd.count - println("# of nulls = " + count) - })*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala deleted file mode 100644 index 34e7edfda9..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala +++ /dev/null @@ -1,92 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.collection.JavaConversions.mapAsScalaMap - -import java.lang.{Long => JLong} -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - -object DumbWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - def add(v1: JLong, v2: JLong) = (v1 + v2) - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - - map.toIterator - } - - val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - wordCounts.persist(StorageLevel.MEMORY_ONLY) - val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala b/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala new file mode 100644 index 0000000000..d56fdcdf29 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala @@ -0,0 +1,41 @@ +package spark.streaming.examples + +import spark.RDD +import spark.streaming.SparkStreamContext +import spark.streaming.SparkStreamContext._ +import spark.streaming.Seconds + +import scala.collection.mutable.SynchronizedQueue + +object ExampleOne { + + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: ExampleOne ") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new SparkStreamContext(args(0), "ExampleOne") + ssc.setBatchDuration(Seconds(1)) + + // Create the queue through which RDDs can be pushed to + // a QueueInputRDS + val rddQueue = new SynchronizedQueue[RDD[Int]]() + + // Create the QueueInputRDs and use it do some processing + val inputStream = ssc.createQueueStream(rddQueue) + val mappedStream = inputStream.map(x => (x % 10, 1)) + val reducedStream = mappedStream.reduceByKey(_ + _) + reducedStream.print() + ssc.start() + + // Create and push some RDDs into + for (i <- 1 to 30) { + rddQueue += ssc.sc.makeRDD(1 to 1000, 10) + Thread.sleep(1000) + } + ssc.stop() + System.exit(0) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala b/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala new file mode 100644 index 0000000000..4b8f6d609d --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala @@ -0,0 +1,47 @@ +package spark.streaming.examples + +import spark.streaming.SparkStreamContext +import spark.streaming.SparkStreamContext._ +import spark.streaming.Seconds +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + + +object ExampleTwo { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: ExampleOne ") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new SparkStreamContext(args(0), "ExampleTwo") + ssc.setBatchDuration(Seconds(2)) + + // Create the new directory + val directory = new Path(args(1)) + val fs = directory.getFileSystem(new Configuration()) + if (fs.exists(directory)) throw new Exception("This directory already exists") + fs.mkdirs(directory) + + // Create the FileInputRDS on the directory and use the + // stream to count words in new files created + val inputRDS = ssc.createTextFileStream(directory.toString) + val wordsRDS = inputRDS.flatMap(_.split(" ")) + val wordCountsRDS = wordsRDS.map(x => (x, 1)).reduceByKey(_ + _) + wordCountsRDS.print + ssc.start() + + // Creating new files in the directory + val text = "This is a text file" + for (i <- 1 to 30) { + ssc.sc.parallelize((1 to (i * 10)).map(_ => text), 10) + .saveAsTextFile(new Path(directory, i.toString).toString) + Thread.sleep(1000) + } + Thread.sleep(5000) // Waiting for the file to be processed + ssc.stop() + fs.delete(directory) + System.exit(0) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepCount.scala b/streaming/src/main/scala/spark/streaming/examples/GrepCount.scala deleted file mode 100644 index ec3e70f258..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/GrepCount.scala +++ /dev/null @@ -1,39 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object GrepCount { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: GrepCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "GrepCount") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - //sentences.print - val matching = sentences.filter(_.contains("light")) - matching.foreachRDD(rdd => println(rdd.count)) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala b/streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala deleted file mode 100644 index 27ecced1c0..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala +++ /dev/null @@ -1,113 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkEnv -import spark.SparkContext -import spark.storage.StorageLevel -import spark.network.Message -import spark.network.ConnectionManagerId - -import java.nio.ByteBuffer - -object GrepCount2 { - - def startSparkEnvs(sc: SparkContext) { - - val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) - sc.runJob(dummy, (_: Iterator[Int]) => {}) - - println("SparkEnvs started") - Thread.sleep(1000) - /*sc.runJob(sc.parallelize(0 to 1000, 100), (_: Iterator[Int]) => {})*/ - } - - def warmConnectionManagers(sc: SparkContext) { - val slaveConnManagerIds = sc.parallelize(0 to 100, 100).map( - i => SparkEnv.get.connectionManager.id).collect().distinct - println("\nSlave ConnectionManagerIds") - slaveConnManagerIds.foreach(println) - println - - Thread.sleep(1000) - val numSlaves = slaveConnManagerIds.size - val count = 3 - val size = 5 * 1024 * 1024 - val iterations = (500 * 1024 * 1024 / (numSlaves * size)).toInt - println("count = " + count + ", size = " + size + ", iterations = " + iterations) - - (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until numSlaves, numSlaves).map(i => { - val connManager = SparkEnv.get.connectionManager - val thisConnManagerId = connManager.id - /*connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - })*/ - - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - val futures = (0 until iterations).map(i => { - slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - println("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") - connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) - }) - }).flatMap(x => x) - val results = futures.map(f => f()) - val finishTime = System.currentTimeMillis - - - val mb = size * results.size / 1024.0 / 1024.0 - val ms = finishTime - startTime - - val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" - println(resultStr) - System.gc() - resultStr - }).collect() - - println("---------------------") - println("Run " + i) - resultStrs.foreach(println) - println("---------------------") - }) - } - - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: GrepCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "GrepCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - /*startSparkEnvs(ssc.sc)*/ - warmConnectionManagers(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-"+i, 500)).toArray - ) - - val matching = sentences.filter(_.contains("light")) - matching.foreachRDD(rdd => println(rdd.count)) - - ssc.run - } -} - - - - diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala b/streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala deleted file mode 100644 index f9674136fe..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala +++ /dev/null @@ -1,54 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object GrepCountApprox { - var inputFile : String = null - var hdfs : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 5) { - println ("Usage: GrepCountApprox ") - System.exit(1) - } - - hdfs = args(1) - inputFile = hdfs + args(2) - idealPartitions = args(3).toInt - val timeout = args(4).toLong - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "GrepCount") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - ssc.setTempDir(hdfs + "/tmp") - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - //sentences.print - val matching = sentences.filter(_.contains("light")) - var i = 0 - val startTime = System.currentTimeMillis - matching.foreachRDD { rdd => - val myNum = i - val result = rdd.countApprox(timeout) - val initialTime = (System.currentTimeMillis - startTime) / 1000.0 - printf("APPROX\t%.2f\t%d\tinitial\t%.1f\t%.1f\n", initialTime, myNum, result.initialValue.mean, - result.initialValue.high - result.initialValue.low) - result.onComplete { r => - val finalTime = (System.currentTimeMillis - startTime) / 1000.0 - printf("APPROX\t%.2f\t%d\tfinal\t%.1f\t0.0\t%.1f\n", finalTime, myNum, r.mean, finalTime - initialTime) - } - i += 1 - } - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala deleted file mode 100644 index a75ccd3a56..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala +++ /dev/null @@ -1,30 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -object SimpleWordCount { - - def main (args: Array[String]) { - - if (args.length < 1) { - println ("Usage: SparkStreamContext []") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount") - if (args.length > 1) { - ssc.setTempDir(args(1)) - } - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) - /*sentences.print*/ - - val words = sentences.flatMap(_.split(" ")) - - val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1) - counts.print - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala deleted file mode 100644 index 9672e64b13..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala +++ /dev/null @@ -1,51 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import scala.util.Sorting - -object SimpleWordCount2 { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SimpleWordCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - - val words = sentences.flatMap(_.split(" ")) - - val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10) - counts.foreachRDD(_.collect()) - /*words.foreachRDD(_.countByValue())*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala deleted file mode 100644 index 503033a8e5..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala +++ /dev/null @@ -1,83 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import scala.collection.JavaConversions.mapAsScalaMap -import scala.util.Sorting -import java.lang.{Long => JLong} - -object SimpleWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SimpleWordCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 400)).toArray - ) - - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } - - - /*val words = sentences.flatMap(_.split(" "))*/ - /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10)*/ - val counts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - counts.foreachRDD(_.collect()) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala b/streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala deleted file mode 100644 index 031e989c87..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala +++ /dev/null @@ -1,97 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting - -object TopContentCount { - - case class Event(val country: String, val content: String) - - object Event { - def create(string: String): Event = { - val parts = string.split(":") - new Event(parts(0), parts(1)) - } - } - - def main(args: Array[String]) { - - if (args.length < 2) { - println ("Usage: GrepCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "TopContentCount") - val sc = ssc.sc - val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) - sc.runJob(dummy, (_: Iterator[Int]) => {}) - - - val numEventStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - val eventStrings = new UnifiedRDS( - (1 to numEventStreams).map(i => ssc.readTestStream("Events-" + i, 1000)).toArray - ) - - def parse(string: String) = { - val parts = string.split(":") - (parts(0), parts(1)) - } - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val events = eventStrings.map(x => parse(x)) - /*events.print*/ - - val parallelism = 8 - val counts_per_content_per_country = events - .map(x => (x, 1)) - .reduceByKey(_ + _) - /*.reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), parallelism)*/ - /*counts_per_content_per_country.print*/ - - /* - counts_per_content_per_country.persist( - StorageLevel.MEMORY_ONLY_DESER, - StorageLevel.MEMORY_ONLY_DESER_2, - Seconds(1) - )*/ - - val counts_per_country = counts_per_content_per_country - .map(x => (x._1._1, (x._1._2, x._2))) - .groupByKey() - counts_per_country.print - - - def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { - implicit val countOrdering = new Ordering[(String, Int)] { - override def compare(count1: (String, Int), count2: (String, Int)): Int = { - count2._2 - count1._2 - } - } - val array = data.toArray - Sorting.quickSort(array) - val taken = array.take(k) - taken - } - - val k = 10 - val topKContents_per_country = counts_per_country - .map(x => (x._1, topK(x._2, k))) - .map(x => (x._1, x._2.map(_.toString).reduceLeft(_ + ", " + _))) - - topKContents_per_country.print - - ssc.run - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala deleted file mode 100644 index 679ed0a7ef..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala +++ /dev/null @@ -1,103 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting - -object TopKWordCount2 { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) - - def topK(data: Iterator[(String, Int)], k: Int): Iterator[(String, Int)] = { - val taken = new Array[(String, Int)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, Int) = null - var swap: (String, Int) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - println("count = " + count) - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out of " + count + " items") - return taken.toIterator - } - - val k = 10 - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) - }) - - /* - windowedCounts.filter(_ == null).foreachRDD(rdd => { - val count = rdd.count - println("# of nulls = " + count) - })*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala deleted file mode 100644 index c873fbd0f0..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala +++ /dev/null @@ -1,142 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.JavaConversions.mapAsScalaMap -import scala.collection.mutable.Queue - -import java.lang.{Long => JLong} - -object TopKWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "TopKWordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - /*moreWarmup(ssc.sc)*/ - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray - ) - - /*val words = sentences.flatMap(_.split(" "))*/ - - /*def add(v1: Int, v2: Int) = (v1 + v2) */ - /*def subtract(v1: Int, v2: Int) = (v1 - v2) */ - - def add(v1: JLong, v2: JLong) = (v1 + v2) - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } - - - val windowedCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Milliseconds(500), 10) - /*windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1))*/ - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) - - def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { - val taken = new Array[(String, JLong)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, JLong) = null - var swap: (String, JLong) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - println("count = " + count) - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out of " + count + " items") - return taken.toIterator - } - - val k = 50 - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) - }) - - /* - windowedCounts.filter(_ == null).foreachRDD(rdd => { - val count = rdd.count - println("# of nulls = " + count) - })*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount.scala deleted file mode 100644 index fb5508ffcc..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount.scala +++ /dev/null @@ -1,62 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object WordCount { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: WordCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "WordCountWindow") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) - //sentences.print - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), - // System.getProperty("spark.default.parallelism", "1").toInt) - //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) - //windowedCounts.print - - val parallelism = System.getProperty("spark.default.parallelism", "1").toInt - - //val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) - //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) - //val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(_ + _) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), - parallelism) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(10)) - - //windowedCounts.print - windowedCounts.register - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount1.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount1.scala deleted file mode 100644 index 42d985920a..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount1.scala +++ /dev/null @@ -1,46 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object WordCount1 { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: WordCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "WordCountWindow") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - //sentences.print - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala deleted file mode 100644 index 9168a2fe2f..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ /dev/null @@ -1,55 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting - -object WordCount2 { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - /*moreWarmup(ssc.sc)*/ - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 6) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala deleted file mode 100644 index 1920915af7..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala +++ /dev/null @@ -1,94 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.collection.JavaConversions.mapAsScalaMap - -import java.lang.{Long => JLong} -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - - -object WordCount2_ExtraFunctions { - - def add(v1: JLong, v2: JLong) = (v1 + v2) - - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } -} - -object WordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 40).foreach {i => - sc.parallelize(1 to 20000000, 1000) - .map(_ % 1331).map(_.toString) - .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - - GrepCount2.warmConnectionManagers(ssc.sc) - /*moreWarmup(ssc.sc)*/ - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray - ) - - val windowedCounts = sentences - .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions) - .reduceByKeyAndWindow(WordCount2_ExtraFunctions.add _, WordCount2_ExtraFunctions.subtract _, Seconds(10), Milliseconds(500), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount3.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount3.scala deleted file mode 100644 index 018c19a509..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount3.scala +++ /dev/null @@ -1,49 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -object WordCount3 { - - def main (args: Array[String]) { - - if (args.length < 1) { - println ("Usage: SparkStreamContext []") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount") - if (args.length > 1) { - ssc.setTempDir(args(1)) - } - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - /*sentences.print*/ - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - /*val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1)*/ - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), 1) - /*windowedCounts.print */ - - def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { - implicit val countOrdering = new Ordering[(String, Int)] { - override def compare(count1: (String, Int), count2: (String, Int)): Int = { - count2._2 - count1._2 - } - } - val array = data.toArray - Sorting.quickSort(array) - array.take(k) - } - - val k = 10 - val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) - topKWindowedCounts.print - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala deleted file mode 100644 index 82b9fa781d..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala +++ /dev/null @@ -1,41 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ -import spark.SparkContext - -object WordCountEc2 { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: SparkStreamContext ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val ssc = new SparkStreamContext(args(0), "Test") - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - /*sentences.foreach(println)*/ - - val words = sentences.flatMap(_.split(" ")) - /*words.foreach(println)*/ - - val counts = words.map(x => (x, 1)).reduceByKey(_ + _) - /*counts.foreach(println)*/ - - counts.foreachRDD(rdd => rdd.collect.foreach(x => x)) - /*counts.register*/ - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala deleted file mode 100644 index 114dd144f1..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala +++ /dev/null @@ -1,51 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -object WordCountTrivialWindow { - - def main (args: Array[String]) { - - if (args.length < 1) { - println ("Usage: SparkStreamContext []") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCountTrivialWindow") - if (args.length > 1) { - ssc.setTempDir(args(1)) - } - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - /*sentences.print*/ - - val words = sentences.flatMap(_.split(" ")) - - /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1)*/ - /*counts.print*/ - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1) - /*windowedCounts.print */ - - def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { - implicit val countOrdering = new Ordering[(String, Int)] { - override def compare(count1: (String, Int), count2: (String, Int)): Int = { - count2._2 - count1._2 - } - } - val array = data.toArray - Sorting.quickSort(array) - array.take(k) - } - - val k = 10 - val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) - topKWindowedCounts.print - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax.scala deleted file mode 100644 index fbfc48030f..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordMax.scala +++ /dev/null @@ -1,64 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object WordMax { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: WordCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "WordCountWindow") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) - //sentences.print - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - def max(v1: Int, v2: Int) = (if (v1 > v2) v1 else v2) - - //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), - // System.getProperty("spark.default.parallelism", "1").toInt) - //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) - //windowedCounts.print - - val parallelism = System.getProperty("spark.default.parallelism", "1").toInt - - val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) - //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER) - localCounts.persist(StorageLevel.MEMORY_ONLY_DESER_2) - val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(max _) - - //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), - // parallelism) - //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) - - //windowedCounts.print - windowedCounts.register - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala new file mode 100644 index 0000000000..6125bb82eb --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala @@ -0,0 +1,52 @@ +package spark.streaming.util + +class RecurringTimer(period: Long, callback: (Long) => Unit) { + + val minPollTime = 25L + + val pollTime = { + if (period / 10.0 > minPollTime) { + (period / 10.0).toLong + } else { + minPollTime + } + } + + val thread = new Thread() { + override def run() { loop } + } + + var nextTime = 0L + + def start(): Long = { + nextTime = (math.floor(System.currentTimeMillis() / period) + 1).toLong * period + thread.start() + nextTime + } + + def stop() { + thread.interrupt() + } + + def loop() { + try { + while (true) { + val beforeSleepTime = System.currentTimeMillis() + while (beforeSleepTime >= nextTime) { + callback(nextTime) + nextTime += period + } + val sleepTime = if (nextTime - beforeSleepTime < 2 * pollTime) { + nextTime - beforeSleepTime + } else { + pollTime + } + Thread.sleep(sleepTime) + val afterSleepTime = System.currentTimeMillis() + } + } catch { + case e: InterruptedException => + } + } +} + diff --git a/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala b/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala new file mode 100644 index 0000000000..9925b1d07c --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala @@ -0,0 +1,64 @@ +package spark.streaming.util + +import java.net.{Socket, ServerSocket} +import java.io.{ByteArrayOutputStream, DataOutputStream, DataInputStream, BufferedInputStream} + +object Receiver { + def main(args: Array[String]) { + val port = args(0).toInt + val lsocket = new ServerSocket(port) + println("Listening on port " + port ) + while(true) { + val socket = lsocket.accept() + (new Thread() { + override def run() { + val buffer = new Array[Byte](100000) + var count = 0 + val time = System.currentTimeMillis + try { + val is = new DataInputStream(new BufferedInputStream(socket.getInputStream)) + var loop = true + var string: String = null + while((string = is.readUTF) != null) { + count += 28 + } + } catch { + case e: Exception => e.printStackTrace + } + val timeTaken = System.currentTimeMillis - time + val tput = (count / 1024.0) / (timeTaken / 1000.0) + println("Data = " + count + " bytes\nTime = " + timeTaken + " ms\nTput = " + tput + " KB/s") + } + }).start() + } + } + +} + +object Sender { + + def main(args: Array[String]) { + try { + val host = args(0) + val port = args(1).toInt + val size = args(2).toInt + + val byteStream = new ByteArrayOutputStream() + val stringDataStream = new DataOutputStream(byteStream) + (0 until size).foreach(_ => stringDataStream.writeUTF("abcdedfghijklmnopqrstuvwxy")) + val bytes = byteStream.toByteArray() + println("Generated array of " + bytes.length + " bytes") + + /*val bytes = new Array[Byte](size)*/ + val socket = new Socket(host, port) + val os = socket.getOutputStream + os.write(bytes) + os.flush + socket.close() + + } catch { + case e: Exception => e.printStackTrace + } + } +} + diff --git a/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala b/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala new file mode 100644 index 0000000000..94e8f7a849 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala @@ -0,0 +1,92 @@ +package spark.streaming.util + +import spark._ + +import scala.collection.mutable.ArrayBuffer +import scala.util.Random +import scala.io.Source + +import java.net.InetSocketAddress + +import org.apache.hadoop.fs._ +import org.apache.hadoop.conf._ +import org.apache.hadoop.io._ +import org.apache.hadoop.mapred._ +import org.apache.hadoop.util._ + +object SentenceFileGenerator { + + def printUsage () { + println ("Usage: SentenceFileGenerator <# partitions> []") + System.exit(0) + } + + def main (args: Array[String]) { + if (args.length < 4) { + printUsage + } + + val master = args(0) + val fs = new Path(args(1)).getFileSystem(new Configuration()) + val targetDirectory = new Path(args(1)).makeQualified(fs) + val numPartitions = args(2).toInt + val sentenceFile = args(3) + val sentencesPerSecond = { + if (args.length > 4) args(4).toInt + else 10 + } + + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n").toArray + source.close () + println("Read " + lines.length + " lines from file " + sentenceFile) + + val sentences = { + val buffer = ArrayBuffer[String]() + val random = new Random() + var i = 0 + while (i < sentencesPerSecond) { + buffer += lines(random.nextInt(lines.length)) + i += 1 + } + buffer.toArray + } + println("Generated " + sentences.length + " sentences") + + val sc = new SparkContext(master, "SentenceFileGenerator") + val sentencesRDD = sc.parallelize(sentences, numPartitions) + + val tempDirectory = new Path(targetDirectory, "_tmp") + + fs.mkdirs(targetDirectory) + fs.mkdirs(tempDirectory) + + var saveTimeMillis = System.currentTimeMillis + try { + while (true) { + val newDir = new Path(targetDirectory, "Sentences-" + saveTimeMillis) + val tmpNewDir = new Path(tempDirectory, "Sentences-" + saveTimeMillis) + println("Writing to file " + newDir) + sentencesRDD.saveAsTextFile(tmpNewDir.toString) + fs.rename(tmpNewDir, newDir) + saveTimeMillis += 1000 + val sleepTimeMillis = { + val currentTimeMillis = System.currentTimeMillis + if (saveTimeMillis < currentTimeMillis) { + 0 + } else { + saveTimeMillis - currentTimeMillis + } + } + println("Sleeping for " + sleepTimeMillis + " ms") + Thread.sleep(sleepTimeMillis) + } + } catch { + case e: Exception => + } + } +} + + + + diff --git a/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala b/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala new file mode 100644 index 0000000000..60085f4f88 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala @@ -0,0 +1,23 @@ +package spark.streaming.util + +import spark.SparkContext +import SparkContext._ + +object ShuffleTest { + def main(args: Array[String]) { + + if (args.length < 1) { + println ("Usage: ShuffleTest ") + System.exit(1) + } + + val sc = new spark.SparkContext(args(0), "ShuffleTest") + val rdd = sc.parallelize(1 to 1000, 500).cache + + def time(f: => Unit) { val start = System.nanoTime; f; println((System.nanoTime - start) * 1.0e-6) } + + time { for (i <- 0 until 50) time { rdd.map(x => (x % 100, x)).reduceByKey(_ + _, 10).count } } + System.exit(0) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/util/Utils.scala b/streaming/src/main/scala/spark/streaming/util/Utils.scala new file mode 100644 index 0000000000..86a729fb49 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/Utils.scala @@ -0,0 +1,9 @@ +package spark.streaming.util + +object Utils { + def time(func: => Unit): Long = { + val t = System.currentTimeMillis + func + (System.currentTimeMillis - t) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala b/streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala deleted file mode 100644 index bb32089ae2..0000000000 --- a/streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala +++ /dev/null @@ -1,78 +0,0 @@ -package spark.streaming - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.net.InetSocketAddress - -/*import akka.actor.Actor._*/ -/*import akka.actor.ActorRef*/ - - -object SenGeneratorForPerformanceTest { - - def printUsage () { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - - def main (args: Array[String]) { - if (args.length < 3) { - printUsage - } - - val inputManagerIP = args(0) - val inputManagerPort = args(1).toInt - val sentenceFile = args(2) - val sentencesPerSecond = { - if (args.length > 3) args(3).toInt - else 10 - } - - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close () - - try { - /*val inputManager = remote.actorFor("InputReceiver-Sentences",*/ - /* inputManagerIP, inputManagerPort)*/ - val inputManager = select(Node(inputManagerIP, inputManagerPort), Symbol("InputReceiver-Sentences")) - val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 - val random = new Random () - println ("Sending " + sentencesPerSecond + " sentences per second to " + inputManagerIP + ":" + inputManagerPort) - var lastPrintTime = System.currentTimeMillis() - var count = 0 - - while (true) { - /*if (!inputManager.tryTell (lines (random.nextInt (lines.length))))*/ - /*throw new Exception ("disconnected")*/ -// inputManager ! lines (random.nextInt (lines.length)) - for (i <- 0 to sentencesPerSecond) inputManager ! lines (0) - println(System.currentTimeMillis / 1000 + " s") -/* count += 1 - - if (System.currentTimeMillis - lastPrintTime >= 1000) { - println (count + " sentences sent last second") - count = 0 - lastPrintTime = System.currentTimeMillis - } - - Thread.sleep (sleepBetweenSentences.toLong) -*/ - val currentMs = System.currentTimeMillis / 1000; - Thread.sleep ((currentMs * 1000 + 1000) - System.currentTimeMillis) - } - } catch { - case e: Exception => - /*Thread.sleep (1000)*/ - } - } -} - - - - diff --git a/streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala b/streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala deleted file mode 100644 index 6af270298a..0000000000 --- a/streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala +++ /dev/null @@ -1,63 +0,0 @@ -package spark.streaming -import java.net.{Socket, ServerSocket} -import java.io.{ByteArrayOutputStream, DataOutputStream, DataInputStream, BufferedInputStream} - -object Receiver { - def main(args: Array[String]) { - val port = args(0).toInt - val lsocket = new ServerSocket(port) - println("Listening on port " + port ) - while(true) { - val socket = lsocket.accept() - (new Thread() { - override def run() { - val buffer = new Array[Byte](100000) - var count = 0 - val time = System.currentTimeMillis - try { - val is = new DataInputStream(new BufferedInputStream(socket.getInputStream)) - var loop = true - var string: String = null - while((string = is.readUTF) != null) { - count += 28 - } - } catch { - case e: Exception => e.printStackTrace - } - val timeTaken = System.currentTimeMillis - time - val tput = (count / 1024.0) / (timeTaken / 1000.0) - println("Data = " + count + " bytes\nTime = " + timeTaken + " ms\nTput = " + tput + " KB/s") - } - }).start() - } - } - -} - -object Sender { - - def main(args: Array[String]) { - try { - val host = args(0) - val port = args(1).toInt - val size = args(2).toInt - - val byteStream = new ByteArrayOutputStream() - val stringDataStream = new DataOutputStream(byteStream) - (0 until size).foreach(_ => stringDataStream.writeUTF("abcdedfghijklmnopqrstuvwxy")) - val bytes = byteStream.toByteArray() - println("Generated array of " + bytes.length + " bytes") - - /*val bytes = new Array[Byte](size)*/ - val socket = new Socket(host, port) - val os = socket.getOutputStream - os.write(bytes) - os.flush - socket.close() - - } catch { - case e: Exception => e.printStackTrace - } - } -} - diff --git a/streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala b/streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala deleted file mode 100644 index 15858f59e3..0000000000 --- a/streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala +++ /dev/null @@ -1,92 +0,0 @@ -package spark.streaming - -import spark._ - -import scala.collection.mutable.ArrayBuffer -import scala.util.Random -import scala.io.Source - -import java.net.InetSocketAddress - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -object SentenceFileGenerator { - - def printUsage () { - println ("Usage: SentenceFileGenerator <# partitions> []") - System.exit(0) - } - - def main (args: Array[String]) { - if (args.length < 4) { - printUsage - } - - val master = args(0) - val fs = new Path(args(1)).getFileSystem(new Configuration()) - val targetDirectory = new Path(args(1)).makeQualified(fs) - val numPartitions = args(2).toInt - val sentenceFile = args(3) - val sentencesPerSecond = { - if (args.length > 4) args(4).toInt - else 10 - } - - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n").toArray - source.close () - println("Read " + lines.length + " lines from file " + sentenceFile) - - val sentences = { - val buffer = ArrayBuffer[String]() - val random = new Random() - var i = 0 - while (i < sentencesPerSecond) { - buffer += lines(random.nextInt(lines.length)) - i += 1 - } - buffer.toArray - } - println("Generated " + sentences.length + " sentences") - - val sc = new SparkContext(master, "SentenceFileGenerator") - val sentencesRDD = sc.parallelize(sentences, numPartitions) - - val tempDirectory = new Path(targetDirectory, "_tmp") - - fs.mkdirs(targetDirectory) - fs.mkdirs(tempDirectory) - - var saveTimeMillis = System.currentTimeMillis - try { - while (true) { - val newDir = new Path(targetDirectory, "Sentences-" + saveTimeMillis) - val tmpNewDir = new Path(tempDirectory, "Sentences-" + saveTimeMillis) - println("Writing to file " + newDir) - sentencesRDD.saveAsTextFile(tmpNewDir.toString) - fs.rename(tmpNewDir, newDir) - saveTimeMillis += 1000 - val sleepTimeMillis = { - val currentTimeMillis = System.currentTimeMillis - if (saveTimeMillis < currentTimeMillis) { - 0 - } else { - saveTimeMillis - currentTimeMillis - } - } - println("Sleeping for " + sleepTimeMillis + " ms") - Thread.sleep(sleepTimeMillis) - } - } catch { - case e: Exception => - } - } -} - - - - diff --git a/streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala b/streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala deleted file mode 100644 index a9f124d2d7..0000000000 --- a/streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala +++ /dev/null @@ -1,103 +0,0 @@ -package spark.streaming - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.net.InetSocketAddress - - -object SentenceGenerator { - - def printUsage { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - - def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 - val random = new Random () - - try { - var lastPrintTime = System.currentTimeMillis() - var count = 0 - while(true) { - streamReceiver ! lines(random.nextInt(lines.length)) - count += 1 - if (System.currentTimeMillis - lastPrintTime >= 1000) { - println (count + " sentences sent last second") - count = 0 - lastPrintTime = System.currentTimeMillis - } - Thread.sleep(sleepBetweenSentences.toLong) - } - } catch { - case e: Exception => - } - } - - def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - try { - val numSentences = if (sentencesPerSecond <= 0) { - lines.length - } else { - sentencesPerSecond - } - var nextSendingTime = System.currentTimeMillis() - val pingInterval = if (System.getenv("INTERVAL") != null) { - System.getenv("INTERVAL").toInt - } else { - 2000 - } - while(true) { - (0 until numSentences).foreach(i => { - streamReceiver ! lines(i % lines.length) - }) - println ("Sent " + numSentences + " sentences") - nextSendingTime += pingInterval - val sleepTime = nextSendingTime - System.currentTimeMillis - if (sleepTime > 0) { - println ("Sleeping for " + sleepTime + " ms") - Thread.sleep(sleepTime) - } - } - } catch { - case e: Exception => - } - } - - def main(args: Array[String]) { - if (args.length < 3) { - printUsage - } - - val generateRandomly = false - - val streamReceiverIP = args(0) - val streamReceiverPort = args(1).toInt - val sentenceFile = args(2) - val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10 - val sentenceInputName = if (args.length > 4) args(4) else "Sentences" - - println("Sending " + sentencesPerSecond + " sentences per second to " + - streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName) - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close () - - val streamReceiver = select( - Node(streamReceiverIP, streamReceiverPort), - Symbol("NetworkStreamReceiver-" + sentenceInputName)) - if (generateRandomly) { - generateRandomSentences(lines, sentencesPerSecond, streamReceiver) - } else { - generateSameSentences(lines, sentencesPerSecond, streamReceiver) - } - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala b/streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala deleted file mode 100644 index 32aa4144a0..0000000000 --- a/streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala +++ /dev/null @@ -1,22 +0,0 @@ -package spark.streaming -import spark.SparkContext -import SparkContext._ - -object ShuffleTest { - def main(args: Array[String]) { - - if (args.length < 1) { - println ("Usage: ShuffleTest ") - System.exit(1) - } - - val sc = new spark.SparkContext(args(0), "ShuffleTest") - val rdd = sc.parallelize(1 to 1000, 500).cache - - def time(f: => Unit) { val start = System.nanoTime; f; println((System.nanoTime - start) * 1.0e-6) } - - time { for (i <- 0 until 50) time { rdd.map(x => (x % 100, x)).reduceByKey(_ + _, 10).count } } - System.exit(0) - } -} - diff --git a/streaming/src/test/scala/spark/streaming/RDSSuite.scala b/streaming/src/test/scala/spark/streaming/RDSSuite.scala new file mode 100644 index 0000000000..f51ea50a5d --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/RDSSuite.scala @@ -0,0 +1,65 @@ +package spark.streaming + +import spark.RDD + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.SynchronizedQueue + +class RDSSuite extends FunSuite with BeforeAndAfter { + + var ssc: SparkStreamContext = null + val batchDurationMillis = 1000 + + def testOp[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: RDS[U] => RDS[V], + expectedOutput: Seq[Seq[V]]) = { + try { + ssc = new SparkStreamContext("local", "test") + ssc.setBatchDuration(Milliseconds(batchDurationMillis)) + + val inputStream = ssc.createQueueStream(input.map(ssc.sc.makeRDD(_, 2)).toIterator) + val outputStream = operation(inputStream) + val outputQueue = outputStream.toQueue + + ssc.start() + Thread.sleep(batchDurationMillis * input.size) + + val output = new ArrayBuffer[Seq[V]]() + while(outputQueue.size > 0) { + val rdd = outputQueue.take() + println("Collecting RDD " + rdd.id + ", " + rdd.getClass().getSimpleName() + ", " + rdd.splits.size) + output += (rdd.collect()) + } + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i).toList === expectedOutput(i).toList) + } + } finally { + ssc.stop() + } + } + + test("basic operations") { + val inputData = Array(1 to 4, 5 to 8, 9 to 12) + + // map + testOp(inputData, (r: RDS[Int]) => r.map(_.toString), inputData.map(_.map(_.toString))) + + // flatMap + testOp(inputData, (r: RDS[Int]) => r.flatMap(x => Array(x, x * 2)), + inputData.map(_.flatMap(x => Array(x, x * 2))) + ) + } +} + +object RDSSuite { + def main(args: Array[String]) { + val r = new RDSSuite() + val inputData = Array(1 to 4, 5 to 8, 9 to 12) + r.testOp(inputData, (r: RDS[Int]) => r.map(_.toString), inputData.map(_.map(_.toString))) + } +} \ No newline at end of file -- cgit v1.2.3 From 886b39de557b4d5f54f5ca11559fca9799534280 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 10 Aug 2012 01:10:02 -0700 Subject: Add Python API. --- .../main/scala/spark/api/python/PythonRDD.scala | 147 ++++++ pyspark/pyspark/__init__.py | 0 pyspark/pyspark/context.py | 69 +++ pyspark/pyspark/examples/__init__.py | 0 pyspark/pyspark/examples/kmeans.py | 56 +++ pyspark/pyspark/examples/pi.py | 20 + pyspark/pyspark/examples/tc.py | 49 ++ pyspark/pyspark/java_gateway.py | 20 + pyspark/pyspark/join.py | 104 +++++ pyspark/pyspark/rdd.py | 517 +++++++++++++++++++++ pyspark/pyspark/serializers.py | 229 +++++++++ pyspark/pyspark/worker.py | 97 ++++ pyspark/requirements.txt | 9 + python/tc.py | 22 + 14 files changed, 1339 insertions(+) create mode 100644 core/src/main/scala/spark/api/python/PythonRDD.scala create mode 100644 pyspark/pyspark/__init__.py create mode 100644 pyspark/pyspark/context.py create mode 100644 pyspark/pyspark/examples/__init__.py create mode 100644 pyspark/pyspark/examples/kmeans.py create mode 100644 pyspark/pyspark/examples/pi.py create mode 100644 pyspark/pyspark/examples/tc.py create mode 100644 pyspark/pyspark/java_gateway.py create mode 100644 pyspark/pyspark/join.py create mode 100644 pyspark/pyspark/rdd.py create mode 100644 pyspark/pyspark/serializers.py create mode 100644 pyspark/pyspark/worker.py create mode 100644 pyspark/requirements.txt create mode 100644 python/tc.py (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala new file mode 100644 index 0000000000..660ad48afe --- /dev/null +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -0,0 +1,147 @@ +package spark.api.python + +import java.io.PrintWriter + +import scala.collection.Map +import scala.collection.JavaConversions._ +import scala.io.Source +import spark._ +import api.java.{JavaPairRDD, JavaRDD} +import scala.Some + +trait PythonRDDBase { + def compute[T](split: Split, envVars: Map[String, String], + command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[String]= { + val currentEnvVars = new ProcessBuilder().environment() + val SPARK_HOME = currentEnvVars.get("SPARK_HOME") + + val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) + // Add the environmental variables to the process. + envVars.foreach { + case (variable, value) => currentEnvVars.put(variable, value) + } + + val proc = pb.start() + val env = SparkEnv.get + + // Start a thread to print the process's stderr to ours + new Thread("stderr reader for " + command) { + override def run() { + for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { + System.err.println(line) + } + } + }.start() + + // Start a thread to feed the process input from our parent's iterator + new Thread("stdin writer for " + command) { + override def run() { + SparkEnv.set(env) + val out = new PrintWriter(proc.getOutputStream) + for (elem <- command) { + out.println(elem) + } + for (elem <- parent.iterator(split)) { + out.println(PythonRDD.pythonDump(elem)) + } + out.close() + } + }.start() + + // Return an iterator that read lines from the process's stdout + val lines: Iterator[String] = Source.fromInputStream(proc.getInputStream).getLines + wrapIterator(lines, proc) + } + + def wrapIterator[T](iter: Iterator[T], proc: Process): Iterator[T] = { + return new Iterator[T] { + def next() = iter.next() + + def hasNext = { + if (iter.hasNext) { + true + } else { + val exitStatus = proc.waitFor() + if (exitStatus != 0) { + throw new Exception("Subprocess exited with status " + exitStatus) + } + false + } + } + } + } +} + +class PythonRDD[T: ClassManifest]( + parent: RDD[T], command: Seq[String], envVars: Map[String, String], + preservePartitoning: Boolean, pythonExec: String) + extends RDD[String](parent.context) with PythonRDDBase { + + def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = + this(parent, command, Map(), preservePartitoning, pythonExec) + + // Similar to Runtime.exec(), if we are given a single string, split it into words + // using a standard StringTokenizer (i.e. by spaces) + def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) = + this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec) + + override def splits = parent.splits + + override val dependencies = List(new OneToOneDependency(parent)) + + override val partitioner = if (preservePartitoning) parent.partitioner else None + + override def compute(split: Split): Iterator[String] = + compute(split, envVars, command, parent, pythonExec) + + val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) +} + +class PythonPairRDD[T: ClassManifest] ( + parent: RDD[T], command: Seq[String], envVars: Map[String, String], + preservePartitoning: Boolean, pythonExec: String) + extends RDD[(String, String)](parent.context) with PythonRDDBase { + + def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = + this(parent, command, Map(), preservePartitoning, pythonExec) + + // Similar to Runtime.exec(), if we are given a single string, split it into words + // using a standard StringTokenizer (i.e. by spaces) + def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) = + this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec) + + override def splits = parent.splits + + override val dependencies = List(new OneToOneDependency(parent)) + + override val partitioner = if (preservePartitoning) parent.partitioner else None + + override def compute(split: Split): Iterator[(String, String)] = { + compute(split, envVars, command, parent, pythonExec).grouped(2).map { + case Seq(a, b) => (a, b) + case x => throw new Exception("Unexpected value: " + x) + } + } + + val asJavaPairRDD : JavaPairRDD[String, String] = JavaPairRDD.fromRDD(this) +} + +object PythonRDD { + def pythonDump[T](x: T): String = { + if (x.isInstanceOf[scala.Option[_]]) { + val t = x.asInstanceOf[scala.Option[_]] + t match { + case None => "*" + case Some(z) => pythonDump(z) + } + } else if (x.isInstanceOf[scala.Tuple2[_, _]]) { + val t = x.asInstanceOf[scala.Tuple2[_, _]] + "(" + pythonDump(t._1) + "," + pythonDump(t._2) + ")" + } else if (x.isInstanceOf[java.util.List[_]]) { + val objs = asScalaBuffer(x.asInstanceOf[java.util.List[_]]).map(pythonDump) + "[" + objs.mkString("|") + "]" + } else { + x.toString + } + } +} diff --git a/pyspark/pyspark/__init__.py b/pyspark/pyspark/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py new file mode 100644 index 0000000000..587ab12b5f --- /dev/null +++ b/pyspark/pyspark/context.py @@ -0,0 +1,69 @@ +import os +import atexit +from tempfile import NamedTemporaryFile + +from pyspark.java_gateway import launch_gateway +from pyspark.serializers import JSONSerializer, NopSerializer +from pyspark.rdd import RDD, PairRDD + + +class SparkContext(object): + + gateway = launch_gateway() + jvm = gateway.jvm + python_dump = jvm.spark.api.python.PythonRDD.pythonDump + + def __init__(self, master, name, defaultSerializer=JSONSerializer, + defaultParallelism=None, pythonExec='python'): + self.master = master + self.name = name + self._jsc = self.jvm.JavaSparkContext(master, name) + self.defaultSerializer = defaultSerializer + self.defaultParallelism = \ + defaultParallelism or self._jsc.sc().defaultParallelism() + self.pythonExec = pythonExec + + def __del__(self): + if self._jsc: + self._jsc.stop() + + def stop(self): + self._jsc.stop() + self._jsc = None + + def parallelize(self, c, numSlices=None, serializer=None): + serializer = serializer or self.defaultSerializer + numSlices = numSlices or self.defaultParallelism + # Calling the Java parallelize() method with an ArrayList is too slow, + # because it sends O(n) Py4J commands. As an alternative, serialized + # objects are written to a file and loaded through textFile(). + tempFile = NamedTemporaryFile(delete=False) + tempFile.writelines(serializer.dumps(x) + '\n' for x in c) + tempFile.close() + atexit.register(lambda: os.unlink(tempFile.name)) + return self.textFile(tempFile.name, numSlices, serializer) + + def parallelizePairs(self, c, numSlices=None, keySerializer=None, + valSerializer=None): + """ + >>> sc = SparkContext("local", "test") + >>> rdd = sc.parallelizePairs([(1, 2), (3, 4)]) + >>> rdd.collect() + [(1, 2), (3, 4)] + """ + keySerializer = keySerializer or self.defaultSerializer + valSerializer = valSerializer or self.defaultSerializer + numSlices = numSlices or self.defaultParallelism + tempFile = NamedTemporaryFile(delete=False) + for (k, v) in c: + tempFile.write(keySerializer.dumps(k).rstrip('\r\n') + '\n') + tempFile.write(valSerializer.dumps(v).rstrip('\r\n') + '\n') + tempFile.close() + atexit.register(lambda: os.unlink(tempFile.name)) + jrdd = self.textFile(tempFile.name, numSlices)._pipePairs([], "echo") + return PairRDD(jrdd, self, keySerializer, valSerializer) + + def textFile(self, name, numSlices=None, serializer=NopSerializer): + numSlices = numSlices or self.defaultParallelism + jrdd = self._jsc.textFile(name, numSlices) + return RDD(jrdd, self, serializer) diff --git a/pyspark/pyspark/examples/__init__.py b/pyspark/pyspark/examples/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pyspark/pyspark/examples/kmeans.py b/pyspark/pyspark/examples/kmeans.py new file mode 100644 index 0000000000..0761d6e395 --- /dev/null +++ b/pyspark/pyspark/examples/kmeans.py @@ -0,0 +1,56 @@ +import sys + +from pyspark.context import SparkContext + + +def parseVector(line): + return [float(x) for x in line.split(' ')] + + +def addVec(x, y): + return [a + b for (a, b) in zip(x, y)] + + +def squaredDist(x, y): + return sum((a - b) ** 2 for (a, b) in zip(x, y)) + + +def closestPoint(p, centers): + bestIndex = 0 + closest = float("+inf") + for i in range(len(centers)): + tempDist = squaredDist(p, centers[i]) + if tempDist < closest: + closest = tempDist + bestIndex = i + return bestIndex + + +if __name__ == "__main__": + if len(sys.argv) < 5: + print >> sys.stderr, \ + "Usage: PythonKMeans " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + lines = sc.textFile(sys.argv[2]) + data = lines.map(parseVector).cache() + K = int(sys.argv[3]) + convergeDist = float(sys.argv[4]) + + kPoints = data.takeSample(False, K, 34) + tempDist = 1.0 + + while tempDist > convergeDist: + closest = data.mapPairs( + lambda p : (closestPoint(p, kPoints), (p, 1))) + pointStats = closest.reduceByKey( + lambda (x1, y1), (x2, y2): (addVec(x1, x2), y1 + y2)) + newPoints = pointStats.mapPairs( + lambda (x, (y, z)): (x, [a / z for a in y])).collect() + + tempDist = sum(squaredDist(kPoints[x], y) for (x, y) in newPoints) + + for (x, y) in newPoints: + kPoints[x] = y + + print "Final centers: " + str(kPoints) diff --git a/pyspark/pyspark/examples/pi.py b/pyspark/pyspark/examples/pi.py new file mode 100644 index 0000000000..ad77694c41 --- /dev/null +++ b/pyspark/pyspark/examples/pi.py @@ -0,0 +1,20 @@ +import sys +from random import random +from operator import add +from pyspark.context import SparkContext + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonPi []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + slices = sys.argv[2] if len(sys.argv) > 2 else 2 + n = 100000 * slices + def f(_): + x = random() * 2 - 1 + y = random() * 2 - 1 + return 1 if x ** 2 + y ** 2 < 1 else 0 + count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) + print "Pi is roughly %f" % (4.0 * count / n) diff --git a/pyspark/pyspark/examples/tc.py b/pyspark/pyspark/examples/tc.py new file mode 100644 index 0000000000..2796fdc6ad --- /dev/null +++ b/pyspark/pyspark/examples/tc.py @@ -0,0 +1,49 @@ +import sys +from random import Random +from pyspark.context import SparkContext + +numEdges = 200 +numVertices = 100 +rand = Random(42) + + +def generateGraph(): + edges = set() + while len(edges) < numEdges: + src = rand.randrange(0, numEdges) + dst = rand.randrange(0, numEdges) + if src != dst: + edges.add((src, dst)) + return edges + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonTC []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + slices = sys.argv[2] if len(sys.argv) > 2 else 2 + tc = sc.parallelizePairs(generateGraph(), slices).cache() + + # Linear transitive closure: each round grows paths by one edge, + # by joining the graph's edges with the already-discovered paths. + # e.g. join the path (y, z) from the TC with the edge (x, y) from + # the graph to obtain the path (x, z). + + # Because join() joins on keys, the edges are stored in reversed order. + edges = tc.mapPairs(lambda (x, y): (y, x)) + + oldCount = 0L + nextCount = tc.count() + while True: + oldCount = nextCount + # Perform the join, obtaining an RDD of (y, (z, x)) pairs, + # then project the result to obtain the new (x, z) paths. + new_edges = tc.join(edges).mapPairs(lambda (_, (a, b)): (b, a)) + tc = tc.union(new_edges).distinct().cache() + nextCount = tc.count() + if nextCount == oldCount: + break + + print "TC has %i edges" % tc.count() diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py new file mode 100644 index 0000000000..2df80aee85 --- /dev/null +++ b/pyspark/pyspark/java_gateway.py @@ -0,0 +1,20 @@ +import glob +import os +from py4j.java_gateway import java_import, JavaGateway + + +SPARK_HOME = os.environ["SPARK_HOME"] + + +assembly_jar = glob.glob(os.path.join(SPARK_HOME, "core/target") + \ + "/spark-core-assembly-*-SNAPSHOT.jar")[0] + + +def launch_gateway(): + gateway = JavaGateway.launch_gateway(classpath=assembly_jar, + javaopts=["-Xmx256m"], die_on_exit=True) + java_import(gateway.jvm, "spark.api.java.*") + java_import(gateway.jvm, "spark.api.python.*") + java_import(gateway.jvm, "scala.Tuple2") + java_import(gateway.jvm, "spark.api.python.PythonRDD.pythonDump") + return gateway diff --git a/pyspark/pyspark/join.py b/pyspark/pyspark/join.py new file mode 100644 index 0000000000..c67520fce8 --- /dev/null +++ b/pyspark/pyspark/join.py @@ -0,0 +1,104 @@ +""" +Copyright (c) 2011, Douban Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + + * Neither the name of the Douban Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" +from pyspark.serializers import PairSerializer, OptionSerializer, \ + ArraySerializer + + +def _do_python_join(rdd, other, numSplits, dispatch, valSerializer): + vs = rdd.mapPairs(lambda (k, v): (k, (1, v))) + ws = other.mapPairs(lambda (k, v): (k, (2, v))) + return vs.union(ws).groupByKey(numSplits) \ + .flatMapValues(dispatch, valSerializer) + + +def python_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + return [(v, w) for v in vbuf for w in wbuf] + valSerializer = PairSerializer(rdd.valSerializer, other.valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + + +def python_right_outer_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not vbuf: + vbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + valSerializer = PairSerializer(OptionSerializer(rdd.valSerializer), + other.valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + + +def python_left_outer_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not wbuf: + wbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + valSerializer = PairSerializer(rdd.valSerializer, + OptionSerializer(other.valSerializer)) + return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + + +def python_cogroup(rdd, other, numSplits): + resultValSerializer = PairSerializer( + ArraySerializer(rdd.valSerializer), + ArraySerializer(other.valSerializer)) + vs = rdd.mapPairs(lambda (k, v): (k, (1, v))) + ws = other.mapPairs(lambda (k, v): (k, (2, v))) + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + return (vbuf, wbuf) + return vs.union(ws).groupByKey(numSplits) \ + .mapValues(dispatch, resultValSerializer) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py new file mode 100644 index 0000000000..c892e86b93 --- /dev/null +++ b/pyspark/pyspark/rdd.py @@ -0,0 +1,517 @@ +from base64 import standard_b64encode as b64enc +from cloud.serialization import cloudpickle +from itertools import chain + +from pyspark.serializers import PairSerializer, NopSerializer, \ + OptionSerializer, ArraySerializer +from pyspark.join import python_join, python_left_outer_join, \ + python_right_outer_join, python_cogroup + + +class RDD(object): + + def __init__(self, jrdd, ctx, serializer=None): + self._jrdd = jrdd + self.is_cached = False + self.ctx = ctx + self.serializer = serializer or ctx.defaultSerializer + + def _builder(self, jrdd, ctx): + return RDD(jrdd, ctx, self.serializer) + + @property + def id(self): + return self._jrdd.id() + + @property + def splits(self): + return self._jrdd.splits() + + @classmethod + def _get_pipe_command(cls, command, functions): + if functions and not isinstance(functions, (list, tuple)): + functions = [functions] + worker_args = [command] + for f in functions: + worker_args.append(b64enc(cloudpickle.dumps(f))) + return " ".join(worker_args) + + def cache(self): + self.is_cached = True + self._jrdd.cache() + return self + + def map(self, f, serializer=None, preservesPartitioning=False): + return MappedRDD(self, f, serializer, preservesPartitioning) + + def mapPairs(self, f, keySerializer=None, valSerializer=None, + preservesPartitioning=False): + return PairMappedRDD(self, f, keySerializer, valSerializer, + preservesPartitioning) + + def flatMap(self, f, serializer=None): + """ + >>> rdd = sc.parallelize([2, 3, 4]) + >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect()) + [1, 1, 1, 2, 2, 3] + """ + serializer = serializer or self.ctx.defaultSerializer + dumps = serializer.dumps + loads = self.serializer.loads + def func(x): + pickled_elems = (dumps(y) for y in f(loads(x))) + return "\n".join(pickled_elems) or None + pipe_command = RDD._get_pipe_command("map", [func]) + class_manifest = self._jrdd.classManifest() + jrdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, + False, self.ctx.pythonExec, + class_manifest).asJavaRDD() + return RDD(jrdd, self.ctx, serializer) + + def flatMapPairs(self, f, keySerializer=None, valSerializer=None, + preservesPartitioning=False): + """ + >>> rdd = sc.parallelize([2, 3, 4]) + >>> sorted(rdd.flatMapPairs(lambda x: [(x, x), (x, x)]).collect()) + [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] + """ + keySerializer = keySerializer or self.ctx.defaultSerializer + valSerializer = valSerializer or self.ctx.defaultSerializer + dumpk = keySerializer.dumps + dumpv = valSerializer.dumps + loads = self.serializer.loads + def func(x): + pairs = f(loads(x)) + pickled_pairs = ((dumpk(k), dumpv(v)) for (k, v) in pairs) + return "\n".join(chain.from_iterable(pickled_pairs)) or None + pipe_command = RDD._get_pipe_command("map", [func]) + class_manifest = self._jrdd.classManifest() + python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, + preservesPartitioning, self.ctx.pythonExec, class_manifest) + return PairRDD(python_rdd.asJavaPairRDD(), self.ctx, keySerializer, + valSerializer) + + def filter(self, f): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4, 5]) + >>> rdd.filter(lambda x: x % 2 == 0).collect() + [2, 4] + """ + loads = self.serializer.loads + def filter_func(x): return x if f(loads(x)) else None + return self._builder(self._pipe(filter_func), self.ctx) + + def _pipe(self, functions, command="map"): + class_manifest = self._jrdd.classManifest() + pipe_command = RDD._get_pipe_command(command, functions) + python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, + False, self.ctx.pythonExec, class_manifest) + return python_rdd.asJavaRDD() + + def _pipePairs(self, functions, command="mapPairs", + preservesPartitioning=False): + class_manifest = self._jrdd.classManifest() + pipe_command = RDD._get_pipe_command(command, functions) + python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, + preservesPartitioning, self.ctx.pythonExec, class_manifest) + return python_rdd.asJavaPairRDD() + + def distinct(self): + """ + >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) + [1, 2, 3] + """ + if self.serializer.is_comparable: + return self._builder(self._jrdd.distinct(), self.ctx) + return self.mapPairs(lambda x: (x, "")) \ + .reduceByKey(lambda x, _: x) \ + .map(lambda (x, _): x) + + def sample(self, withReplacement, fraction, seed): + jrdd = self._jrdd.sample(withReplacement, fraction, seed) + return self._builder(jrdd, self.ctx) + + def takeSample(self, withReplacement, num, seed): + vals = self._jrdd.takeSample(withReplacement, num, seed) + return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + + def union(self, other): + """ + >>> rdd = sc.parallelize([1, 1, 2, 3]) + >>> rdd.union(rdd).collect() + [1, 1, 2, 3, 1, 1, 2, 3] + """ + return self._builder(self._jrdd.union(other._jrdd), self.ctx) + + # TODO: sort + + # TODO: Overload __add___? + + # TODO: glom + + def cartesian(self, other): + """ + >>> rdd = sc.parallelize([1, 2]) + >>> sorted(rdd.cartesian(rdd).collect()) + [(1, 1), (1, 2), (2, 1), (2, 2)] + """ + return PairRDD(self._jrdd.cartesian(other._jrdd), self.ctx) + + # numsplits + def groupBy(self, f, numSplits=None): + """ + >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8]) + >>> sorted(rdd.groupBy(lambda x: x % 2).collect()) + [(0, [2, 8]), (1, [1, 1, 3, 5])] + """ + return self.mapPairs(lambda x: (f(x), x)).groupByKey(numSplits) + + # TODO: pipe + + # TODO: mapPartitions + + def foreach(self, f): + """ + >>> def f(x): print x + >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) + """ + self.map(f).collect() # Force evaluation + + def collect(self): + vals = self._jrdd.collect() + return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + + def reduce(self, f, serializer=None): + """ + >>> import operator + >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(operator.add) + 15 + """ + serializer = serializer or self.ctx.defaultSerializer + loads = self.serializer.loads + dumps = serializer.dumps + def reduceFunction(x, acc): + if acc is None: + return loads(x) + else: + return f(loads(x), acc) + vals = self._pipe([reduceFunction, dumps], command="reduce").collect() + return reduce(f, (serializer.loads(x) for x in vals)) + + # TODO: fold + + # TODO: aggregate + + def count(self): + """ + >>> sc.parallelize([2, 3, 4]).count() + 3L + """ + return self._jrdd.count() + + # TODO: count approx methods + + def take(self, num): + """ + >>> sc.parallelize([2, 3, 4]).take(2) + [2, 3] + """ + vals = self._jrdd.take(num) + return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + + def first(self): + """ + >>> sc.parallelize([2, 3, 4]).first() + 2 + """ + return self.serializer.loads(self.ctx.python_dump(self._jrdd.first())) + + # TODO: saveAsTextFile + + # TODO: saveAsObjectFile + + +class PairRDD(RDD): + + def __init__(self, jrdd, ctx, keySerializer=None, valSerializer=None): + RDD.__init__(self, jrdd, ctx) + self.keySerializer = keySerializer or ctx.defaultSerializer + self.valSerializer = valSerializer or ctx.defaultSerializer + self.serializer = \ + PairSerializer(self.keySerializer, self.valSerializer) + + def _builder(self, jrdd, ctx): + return PairRDD(jrdd, ctx, self.keySerializer, self.valSerializer) + + def reduceByKey(self, func, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(x.reduceByKey(lambda a, b: a + b).collect()) + [('a', 2), ('b', 1)] + """ + return self.combineByKey(lambda x: x, func, func, numSplits) + + # TODO: reduceByKeyLocally() + + # TODO: countByKey() + + # TODO: partitionBy + + def join(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2), ("a", 3)]) + >>> x.join(y).collect() + [('a', (1, 2)), ('a', (1, 3))] + + Check that we get a PairRDD-like object back: + >>> assert x.join(y).join + """ + assert self.keySerializer.name == other.keySerializer.name + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.join(other._jrdd), + self.ctx, self.keySerializer, + PairSerializer(self.valSerializer, other.valSerializer)) + else: + return python_join(self, other, numSplits) + + def leftOuterJoin(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2)]) + >>> sorted(x.leftOuterJoin(y).collect()) + [('a', (1, 2)), ('b', (4, None))] + """ + assert self.keySerializer.name == other.keySerializer.name + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.leftOuterJoin(other._jrdd), + self.ctx, self.keySerializer, + PairSerializer(self.valSerializer, + OptionSerializer(other.valSerializer))) + else: + return python_left_outer_join(self, other, numSplits) + + def rightOuterJoin(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2)]) + >>> sorted(y.rightOuterJoin(x).collect()) + [('a', (2, 1)), ('b', (None, 4))] + """ + assert self.keySerializer.name == other.keySerializer.name + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.rightOuterJoin(other._jrdd), + self.ctx, self.keySerializer, + PairSerializer(OptionSerializer(self.valSerializer), + other.valSerializer)) + else: + return python_right_outer_join(self, other, numSplits) + + def combineByKey(self, createCombiner, mergeValue, mergeCombiners, + numSplits=None, serializer=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> def f(x): return x + >>> def add(a, b): return a + str(b) + >>> sorted(x.combineByKey(str, add, add).collect()) + [('a', '11'), ('b', '1')] + """ + serializer = serializer or self.ctx.defaultSerializer + if numSplits is None: + numSplits = self.ctx.defaultParallelism + # Use hash() to create keys that are comparable in Java. + loadkv = self.serializer.loads + def pairify(kv): + # TODO: add method to deserialize only the key or value from + # a PairSerializer? + key = loadkv(kv)[0] + return (str(hash(key)), kv) + partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) + jrdd = self._pipePairs(pairify).partitionBy(partitioner) + pairified = PairRDD(jrdd, self.ctx, NopSerializer, self.serializer) + + loads = PairSerializer(NopSerializer, self.serializer).loads + dumpk = self.keySerializer.dumps + dumpc = serializer.dumps + + functions = [createCombiner, mergeValue, mergeCombiners, loads, dumpk, + dumpc] + jpairs = pairified._pipePairs(functions, "combine_by_key", + preservesPartitioning=True) + return PairRDD(jpairs, self.ctx, self.keySerializer, serializer) + + def groupByKey(self, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(x.groupByKey().collect()) + [('a', [1, 1]), ('b', [1])] + """ + + def createCombiner(x): + return [x] + + def mergeValue(xs, x): + xs.append(x) + return xs + + def mergeCombiners(a, b): + return a + b + + return self.combineByKey(createCombiner, mergeValue, mergeCombiners, + numSplits) + + def collectAsMap(self): + """ + >>> m = sc.parallelizePairs([(1, 2), (3, 4)]).collectAsMap() + >>> m[1] + 2 + >>> m[3] + 4 + """ + m = self._jrdd.collectAsMap() + def loads(x): + (k, v) = x + return (self.keySerializer.loads(k), self.valSerializer.loads(v)) + return dict(loads(x) for x in m.items()) + + def flatMapValues(self, f, valSerializer=None): + flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + return self.flatMapPairs(flat_map_fn, self.keySerializer, + valSerializer, True) + + def mapValues(self, f, valSerializer=None): + map_values_fn = lambda (k, v): (k, f(v)) + return self.mapPairs(map_values_fn, self.keySerializer, valSerializer, + True) + + # TODO: support varargs cogroup of several RDDs. + def groupWith(self, other): + return self.cogroup(other) + + def cogroup(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2)]) + >>> x.cogroup(y).collect() + [('a', ([1], [2])), ('b', ([4], []))] + """ + assert self.keySerializer.name == other.keySerializer.name + resultValSerializer = PairSerializer( + ArraySerializer(self.valSerializer), + ArraySerializer(other.valSerializer)) + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.cogroup(other._jrdd), + self.ctx, self.keySerializer, resultValSerializer) + else: + return python_cogroup(self, other, numSplits) + + # TODO: `lookup` is disabled because we can't make direct comparisons based + # on the key; we need to compare the hash of the key to the hash of the + # keys in the pairs. This could be an expensive operation, since those + # hashes aren't retained. + + # TODO: file saving + + +class MappedRDDBase(object): + def __init__(self, prev, func, serializer, preservesPartitioning=False): + if isinstance(prev, MappedRDDBase) and not prev.is_cached: + prev_func = prev.func + self.func = lambda x: func(prev_func(x)) + self.preservesPartitioning = \ + prev.preservesPartitioning and preservesPartitioning + self._prev_jrdd = prev._prev_jrdd + self._prev_serializer = prev._prev_serializer + else: + self.func = func + self.preservesPartitioning = preservesPartitioning + self._prev_jrdd = prev._jrdd + self._prev_serializer = prev.serializer + self.serializer = serializer or prev.ctx.defaultSerializer + self.is_cached = False + self.ctx = prev.ctx + self.prev = prev + self._jrdd_val = None + + +class MappedRDD(MappedRDDBase, RDD): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4]) + >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + """ + + @property + def _jrdd(self): + if not self._jrdd_val: + udf = self.func + loads = self._prev_serializer.loads + dumps = self.serializer.dumps + func = lambda x: dumps(udf(loads(x))) + pipe_command = RDD._get_pipe_command("map", [func]) + class_manifest = self._prev_jrdd.classManifest() + python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), + pipe_command, self.preservesPartitioning, self.ctx.pythonExec, + class_manifest) + self._jrdd_val = python_rdd.asJavaRDD() + return self._jrdd_val + + +class PairMappedRDD(MappedRDDBase, PairRDD): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4]) + >>> rdd.mapPairs(lambda x: (x, x)) \\ + ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\ + ... .collect() + [(2, 2), (4, 4), (6, 6), (8, 8)] + >>> rdd.mapPairs(lambda x: (x, x)) \\ + ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\ + ... .map(lambda (x, _): x).collect() + [2, 4, 6, 8] + """ + + def __init__(self, prev, func, keySerializer=None, valSerializer=None, + preservesPartitioning=False): + self.keySerializer = keySerializer or prev.ctx.defaultSerializer + self.valSerializer = valSerializer or prev.ctx.defaultSerializer + serializer = PairSerializer(self.keySerializer, self.valSerializer) + MappedRDDBase.__init__(self, prev, func, serializer, + preservesPartitioning) + + @property + def _jrdd(self): + if not self._jrdd_val: + udf = self.func + loads = self._prev_serializer.loads + dumpk = self.keySerializer.dumps + dumpv = self.valSerializer.dumps + def func(x): + (k, v) = udf(loads(x)) + return (dumpk(k), dumpv(v)) + pipe_command = RDD._get_pipe_command("mapPairs", [func]) + class_manifest = self._prev_jrdd.classManifest() + self._jrdd_val = self.ctx.jvm.PythonPairRDD(self._prev_jrdd.rdd(), + pipe_command, self.preservesPartitioning, self.ctx.pythonExec, + class_manifest).asJavaPairRDD() + return self._jrdd_val + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.serializers import PickleSerializer, JSONSerializer + globs = globals().copy() + globs['sc'] = SparkContext('local', 'PythonTest', + defaultSerializer=JSONSerializer) + doctest.testmod(globs=globs) + globs['sc'].stop() + globs['sc'] = SparkContext('local', 'PythonTest', + defaultSerializer=PickleSerializer) + doctest.testmod(globs=globs) + globs['sc'].stop() + + +if __name__ == "__main__": + _test() diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py new file mode 100644 index 0000000000..b113f5656b --- /dev/null +++ b/pyspark/pyspark/serializers.py @@ -0,0 +1,229 @@ +""" +Data serialization methods. + +The Spark Python API is built on top of the Spark Java API. RDDs created in +Python are stored in Java as RDDs of Strings. Python objects are automatically +serialized/deserialized, so this representation is transparent to the end-user. + +------------------ +Serializer objects +------------------ + +`Serializer` objects are used to customize how an RDD's values are serialized. + +Each `Serializer` is a named tuple with four fields: + + - A `dumps` function, for serializing a Python object to a string. + + - A `loads` function, for deserializing a Python object from a string. + + - An `is_comparable` field, True if equal Python objects are serialized to + equal strings, and False otherwise. + + - A `name` field, used to identify the Serializer. Serializers are + compared for equality by comparing their names. + +The serializer's output should be base64-encoded. + +------------------------------------------------------------------ +`is_comparable`: comparing serialized representations for equality +------------------------------------------------------------------ + +If `is_comparable` is False, the serializer's representations of equal objects +are not required to be equal: + +>>> import pickle +>>> a = {1: 0, 9: 0} +>>> b = {9: 0, 1: 0} +>>> a == b +True +>>> pickle.dumps(a) == pickle.dumps(b) +False + +RDDs with comparable serializers can use native Java implementations of +operations like join() and distinct(), which may lead to better performance by +eliminating deserialization and Python comparisons. + +The default JSONSerializer produces comparable representations of common Python +data structures. + +-------------------------------------- +Examples of serialized representations +-------------------------------------- + +The RDD transformations that use Python UDFs are implemented in terms of +a modified `PipedRDD.pipe()` function. For each record `x` in the RDD, the +`pipe()` function pipes `x.toString()` to a Python worker process, which +deserializes the string into a Python object, executes user-defined functions, +and outputs serialized Python objects. + +The regular `toString()` method returns an ambiguous representation, due to the +way that Scala `Option` instances are printed: + +>>> from context import SparkContext +>>> sc = SparkContext("local", "SerializerDocs") +>>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) +>>> y = sc.parallelizePairs([("a", 2)]) + +>>> print y.rightOuterJoin(x)._jrdd.first().toString() +(ImEi,(Some(Mg==),MQ==)) + +In Java, preprocessing is performed to handle Option instances, so the Python +process receives unambiguous input: + +>>> print sc.python_dump(y.rightOuterJoin(x)._jrdd.first()) +(ImEi,(Mg==,MQ==)) + +The base64-encoding eliminates the need to escape newlines, parentheses and +other special characters. + +---------------------- +Serializer composition +---------------------- + +In order to handle nested structures, which could contain object serialized +with different serializers, the RDD module composes serializers. For example, +the serializers in the previous example are: + +>>> print x.serializer.name +PairSerializer + +>>> print y.serializer.name +PairSerializer + +>>> print y.rightOuterJoin(x).serializer.name +PairSerializer, JSONSerializer>> +""" +from base64 import standard_b64encode, standard_b64decode +from collections import namedtuple +import cPickle +import simplejson + + +Serializer = namedtuple("Serializer", + ["dumps","loads", "is_comparable", "name"]) + + +NopSerializer = Serializer(str, str, True, "NopSerializer") + + +JSONSerializer = Serializer( + lambda obj: standard_b64encode(simplejson.dumps(obj, sort_keys=True, + separators=(',', ':'))), + lambda s: simplejson.loads(standard_b64decode(s)), + True, + "JSONSerializer" +) + + +PickleSerializer = Serializer( + lambda obj: standard_b64encode(cPickle.dumps(obj)), + lambda s: cPickle.loads(standard_b64decode(s)), + False, + "PickleSerializer" +) + + +def OptionSerializer(serializer): + """ + >>> ser = OptionSerializer(NopSerializer) + >>> ser.loads(ser.dumps("Hello, World!")) + 'Hello, World!' + >>> ser.loads(ser.dumps(None)) is None + True + """ + none_placeholder = '*' + + def dumps(x): + if x is None: + return none_placeholder + else: + return serializer.dumps(x) + + def loads(x): + if x == none_placeholder: + return None + else: + return serializer.loads(x) + + name = "OptionSerializer<%s>" % serializer.name + return Serializer(dumps, loads, serializer.is_comparable, name) + + +def PairSerializer(keySerializer, valSerializer): + """ + Returns a Serializer for a (key, value) pair. + + >>> ser = PairSerializer(JSONSerializer, JSONSerializer) + >>> ser.loads(ser.dumps((1, 2))) + (1, 2) + + >>> ser = PairSerializer(JSONSerializer, ser) + >>> ser.loads(ser.dumps((1, (2, 3)))) + (1, (2, 3)) + """ + def loads(kv): + try: + (key, val) = kv[1:-1].split(',', 1) + key = keySerializer.loads(key) + val = valSerializer.loads(val) + return (key, val) + except: + print "Error in deserializing pair from '%s'" % str(kv) + raise + + def dumps(kv): + (key, val) = kv + return"(%s,%s)" % (keySerializer.dumps(key), valSerializer.dumps(val)) + is_comparable = \ + keySerializer.is_comparable and valSerializer.is_comparable + name = "PairSerializer<%s, %s>" % (keySerializer.name, valSerializer.name) + return Serializer(dumps, loads, is_comparable, name) + + +def ArraySerializer(serializer): + """ + >>> ser = ArraySerializer(JSONSerializer) + >>> ser.loads(ser.dumps([1, 2, 3, 4])) + [1, 2, 3, 4] + >>> ser = ArraySerializer(PairSerializer(JSONSerializer, PickleSerializer)) + >>> ser.loads(ser.dumps([('a', 1), ('b', 2)])) + [('a', 1), ('b', 2)] + >>> ser.loads(ser.dumps([('a', 1)])) + [('a', 1)] + >>> ser.loads(ser.dumps([])) + [] + """ + def dumps(arr): + if arr == []: + return '[]' + else: + return '[' + '|'.join(serializer.dumps(x) for x in arr) + ']' + + def loads(s): + if s == '[]': + return [] + items = s[1:-1] + if '|' in items: + items = items.split('|') + else: + items = [items] + return [serializer.loads(x) for x in items] + + name = "ArraySerializer<%s>" % serializer.name + return Serializer(dumps, loads, serializer.is_comparable, name) + + +# TODO: IntegerSerializer + + +# TODO: DoubleSerializer + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py new file mode 100644 index 0000000000..4d4cc939c3 --- /dev/null +++ b/pyspark/pyspark/worker.py @@ -0,0 +1,97 @@ +""" +Worker that receives input from Piped RDD. +""" +import sys +from base64 import standard_b64decode +# CloudPickler needs to be imported so that depicklers are registered using the +# copy_reg module. +from cloud.serialization.cloudpickle import CloudPickler +import cPickle + + +# Redirect stdout to stderr so that users must return values from functions. +old_stdout = sys.stdout +sys.stdout = sys.stderr + + +def load_function(): + return cPickle.loads(standard_b64decode(sys.stdin.readline().strip())) + + +def output(x): + for line in x.split("\n"): + old_stdout.write(line.rstrip("\r\n") + "\n") + + +def read_input(): + for line in sys.stdin: + yield line.rstrip("\r\n") + + +def do_combine_by_key(): + create_combiner = load_function() + merge_value = load_function() + merge_combiners = load_function() # TODO: not used. + depickler = load_function() + key_pickler = load_function() + combiner_pickler = load_function() + combiners = {} + for line in read_input(): + # Discard the hashcode added in the Python combineByKey() method. + (key, value) = depickler(line)[1] + if key not in combiners: + combiners[key] = create_combiner(value) + else: + combiners[key] = merge_value(combiners[key], value) + for (key, combiner) in combiners.iteritems(): + output(key_pickler(key)) + output(combiner_pickler(combiner)) + + +def do_map(map_pairs=False): + f = load_function() + for line in read_input(): + try: + out = f(line) + if out is not None: + if map_pairs: + for x in out: + output(x) + else: + output(out) + except: + sys.stderr.write("Error processing line '%s'\n" % line) + raise + + +def do_reduce(): + f = load_function() + dumps = load_function() + acc = None + for line in read_input(): + acc = f(line, acc) + output(dumps(acc)) + + +def do_echo(): + old_stdout.writelines(sys.stdin.readlines()) + + +def main(): + command = sys.stdin.readline().strip() + if command == "map": + do_map(map_pairs=False) + elif command == "mapPairs": + do_map(map_pairs=True) + elif command == "combine_by_key": + do_combine_by_key() + elif command == "reduce": + do_reduce() + elif command == "echo": + do_echo() + else: + raise Exception("Unsupported command %s" % command) + + +if __name__ == '__main__': + main() diff --git a/pyspark/requirements.txt b/pyspark/requirements.txt new file mode 100644 index 0000000000..d9b3fe40bd --- /dev/null +++ b/pyspark/requirements.txt @@ -0,0 +1,9 @@ +# The Python API relies on some new features from the Py4J development branch. +# pip can't install Py4J from git because the setup.py file for the Python +# package is not at the root of the git repository. It may be possible to +# install Py4J from git once https://github.com/pypa/pip/pull/526 is merged. + +# git+git://github.com/bartdag/py4j.git@3dbf380d3d2cdeb9aab394454ea74d80c4aba1ea + +simplejson==2.6.1 +cloud==2.5.5 diff --git a/python/tc.py b/python/tc.py new file mode 100644 index 0000000000..5dcc4317e0 --- /dev/null +++ b/python/tc.py @@ -0,0 +1,22 @@ +from rdd import SparkContext + +sc = SparkContext("local", "PythonWordCount") +e = [(1, 2), (2, 3), (4, 1)] + +tc = sc.parallelizePairs(e) + +edges = tc.mapPairs(lambda (x, y): (y, x)) + +oldCount = 0 +nextCount = tc.count() + +def project(x): + return (x[1][1], x[1][0]) + +while nextCount != oldCount: + oldCount = nextCount + tc = tc.union(tc.join(edges).mapPairs(project)).distinct() + nextCount = tc.count() + +print "TC has %i edges" % tc.count() +print tc.collect() -- cgit v1.2.3 From fd94e5443c99775bfad1928729f5075c900ad0f9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 18 Aug 2012 16:07:10 -0700 Subject: Use only cPickle for serialization in Python API. Objects serialized with JSON can be compared for equality, but JSON can be slow to serialize and only supports a limited range of data types. --- .../main/scala/spark/api/python/PythonRDD.scala | 192 +++++++--- pyspark/pyspark/context.py | 49 +-- pyspark/pyspark/java_gateway.py | 1 - pyspark/pyspark/join.py | 32 +- pyspark/pyspark/rdd.py | 414 ++++++++------------- pyspark/pyspark/serializers.py | 233 +----------- pyspark/pyspark/worker.py | 64 ++-- 7 files changed, 381 insertions(+), 604 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 660ad48afe..b9a0168d18 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -1,22 +1,26 @@ package spark.api.python -import java.io.PrintWriter +import java.io._ import scala.collection.Map import scala.collection.JavaConversions._ import scala.io.Source import spark._ -import api.java.{JavaPairRDD, JavaRDD} +import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import scala.{collection, Some} +import collection.parallel.mutable +import scala.collection import scala.Some trait PythonRDDBase { def compute[T](split: Split, envVars: Map[String, String], - command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[String]= { - val currentEnvVars = new ProcessBuilder().environment() - val SPARK_HOME = currentEnvVars.get("SPARK_HOME") + command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[Array[Byte]] = { + val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) // Add the environmental variables to the process. + val currentEnvVars = pb.environment() + envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) } @@ -41,33 +45,70 @@ trait PythonRDDBase { for (elem <- command) { out.println(elem) } + out.flush() + val dOut = new DataOutputStream(proc.getOutputStream) for (elem <- parent.iterator(split)) { - out.println(PythonRDD.pythonDump(elem)) + if (elem.isInstanceOf[Array[Byte]]) { + val arr = elem.asInstanceOf[Array[Byte]] + dOut.writeInt(arr.length) + dOut.write(arr) + } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) { + val t = elem.asInstanceOf[scala.Tuple2[_, _]] + val t1 = t._1.asInstanceOf[Array[Byte]] + val t2 = t._2.asInstanceOf[Array[Byte]] + val length = t1.length + t2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes + dOut.writeInt(length) + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.write(PythonRDD.stripPickle(t1)) + dOut.write(PythonRDD.stripPickle(t2)) + dOut.writeByte(Pickle.TUPLE2) + dOut.writeByte(Pickle.STOP) + } else if (elem.isInstanceOf[String]) { + // For uniformity, strings are wrapped into Pickles. + val s = elem.asInstanceOf[String].getBytes("UTF-8") + val length = 2 + 1 + 4 + s.length + 1 + dOut.writeInt(length) + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.writeByte(Pickle.BINUNICODE) + dOut.writeInt(Integer.reverseBytes(s.length)) + dOut.write(s) + dOut.writeByte(Pickle.STOP) + } else { + throw new Exception("Unexpected RDD type") + } } - out.close() + dOut.flush() + out.flush() + proc.getOutputStream.close() } }.start() // Return an iterator that read lines from the process's stdout - val lines: Iterator[String] = Source.fromInputStream(proc.getInputStream).getLines - wrapIterator(lines, proc) - } + val stream = new DataInputStream(proc.getInputStream) + return new Iterator[Array[Byte]] { + def next() = { + val obj = _nextObj + _nextObj = read() + obj + } - def wrapIterator[T](iter: Iterator[T], proc: Process): Iterator[T] = { - return new Iterator[T] { - def next() = iter.next() - - def hasNext = { - if (iter.hasNext) { - true - } else { - val exitStatus = proc.waitFor() - if (exitStatus != 0) { - throw new Exception("Subprocess exited with status " + exitStatus) - } - false + private def read() = { + try { + val length = stream.readInt() + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + } catch { + case eof: EOFException => { new Array[Byte](0) } + case e => throw e } } + + var _nextObj = read() + + def hasNext = _nextObj.length != 0 } } } @@ -75,7 +116,7 @@ trait PythonRDDBase { class PythonRDD[T: ClassManifest]( parent: RDD[T], command: Seq[String], envVars: Map[String, String], preservePartitoning: Boolean, pythonExec: String) - extends RDD[String](parent.context) with PythonRDDBase { + extends RDD[Array[Byte]](parent.context) with PythonRDDBase { def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = this(parent, command, Map(), preservePartitoning, pythonExec) @@ -91,16 +132,16 @@ class PythonRDD[T: ClassManifest]( override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Split): Iterator[String] = + override def compute(split: Split): Iterator[Array[Byte]] = compute(split, envVars, command, parent, pythonExec) - val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) + val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } class PythonPairRDD[T: ClassManifest] ( parent: RDD[T], command: Seq[String], envVars: Map[String, String], preservePartitoning: Boolean, pythonExec: String) - extends RDD[(String, String)](parent.context) with PythonRDDBase { + extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase { def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = this(parent, command, Map(), preservePartitoning, pythonExec) @@ -116,32 +157,95 @@ class PythonPairRDD[T: ClassManifest] ( override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Split): Iterator[(String, String)] = { + override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = { compute(split, envVars, command, parent, pythonExec).grouped(2).map { case Seq(a, b) => (a, b) - case x => throw new Exception("Unexpected value: " + x) + case x => throw new Exception("PythonPairRDD: unexpected value: " + x) } } - val asJavaPairRDD : JavaPairRDD[String, String] = JavaPairRDD.fromRDD(this) + val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } + object PythonRDD { - def pythonDump[T](x: T): String = { - if (x.isInstanceOf[scala.Option[_]]) { - val t = x.asInstanceOf[scala.Option[_]] - t match { - case None => "*" - case Some(z) => pythonDump(z) - } - } else if (x.isInstanceOf[scala.Tuple2[_, _]]) { - val t = x.asInstanceOf[scala.Tuple2[_, _]] - "(" + pythonDump(t._1) + "," + pythonDump(t._2) + ")" - } else if (x.isInstanceOf[java.util.List[_]]) { - val objs = asScalaBuffer(x.asInstanceOf[java.util.List[_]]).map(pythonDump) - "[" + objs.mkString("|") + "]" + + /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ + def stripPickle(arr: Array[Byte]) : Array[Byte] = { + arr.slice(2, arr.length - 1) + } + + def asPickle(elem: Any) : Array[Byte] = { + val baos = new ByteArrayOutputStream(); + val dOut = new DataOutputStream(baos); + if (elem.isInstanceOf[Array[Byte]]) { + elem.asInstanceOf[Array[Byte]] + } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) { + val t = elem.asInstanceOf[scala.Tuple2[_, _]] + val t1 = t._1.asInstanceOf[Array[Byte]] + val t2 = t._2.asInstanceOf[Array[Byte]] + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.write(PythonRDD.stripPickle(t1)) + dOut.write(PythonRDD.stripPickle(t2)) + dOut.writeByte(Pickle.TUPLE2) + dOut.writeByte(Pickle.STOP) + baos.toByteArray() + } else if (elem.isInstanceOf[String]) { + // For uniformity, strings are wrapped into Pickles. + val s = elem.asInstanceOf[String].getBytes("UTF-8") + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.write(Pickle.BINUNICODE) + dOut.writeInt(Integer.reverseBytes(s.length)) + dOut.write(s) + dOut.writeByte(Pickle.STOP) + baos.toByteArray() } else { - x.toString + throw new Exception("Unexpected RDD type") } } + + def pickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : + JavaRDD[Array[Byte]] = { + val file = new DataInputStream(new FileInputStream(filename)) + val objs = new collection.mutable.ArrayBuffer[Array[Byte]] + try { + while (true) { + val length = file.readInt() + val obj = new Array[Byte](length) + file.readFully(obj) + objs.append(obj) + } + } catch { + case eof: EOFException => {} + case e => throw e + } + JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) + } + + def arrayAsPickle(arr : Any) : Array[Byte] = { + val pickles : Array[Byte] = arr.asInstanceOf[Array[Any]].map(asPickle).map(stripPickle).flatten + + Array[Byte](Pickle.PROTO, Pickle.TWO, Pickle.EMPTY_LIST, Pickle.MARK) ++ pickles ++ + Array[Byte] (Pickle.APPENDS, Pickle.STOP) + } +} + +private object Pickle { + def b(x: Int): Byte = x.asInstanceOf[Byte] + val PROTO: Byte = b(0x80) + val TWO: Byte = b(0x02) + val BINUNICODE : Byte = 'X' + val STOP : Byte = '.' + val TUPLE2 : Byte = b(0x86) + val EMPTY_LIST : Byte = ']' + val MARK : Byte = '(' + val APPENDS : Byte = 'e' +} +class ExtractValue extends spark.api.java.function.Function[(Array[Byte], + Array[Byte]), Array[Byte]] { + + override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2 + } diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 587ab12b5f..ac7e4057e9 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -3,22 +3,24 @@ import atexit from tempfile import NamedTemporaryFile from pyspark.java_gateway import launch_gateway -from pyspark.serializers import JSONSerializer, NopSerializer -from pyspark.rdd import RDD, PairRDD +from pyspark.serializers import PickleSerializer, dumps +from pyspark.rdd import RDD class SparkContext(object): gateway = launch_gateway() jvm = gateway.jvm - python_dump = jvm.spark.api.python.PythonRDD.pythonDump + pickleFile = jvm.spark.api.python.PythonRDD.pickleFile + asPickle = jvm.spark.api.python.PythonRDD.asPickle + arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle - def __init__(self, master, name, defaultSerializer=JSONSerializer, - defaultParallelism=None, pythonExec='python'): + + def __init__(self, master, name, defaultParallelism=None, + pythonExec='python'): self.master = master self.name = name self._jsc = self.jvm.JavaSparkContext(master, name) - self.defaultSerializer = defaultSerializer self.defaultParallelism = \ defaultParallelism or self._jsc.sc().defaultParallelism() self.pythonExec = pythonExec @@ -31,39 +33,26 @@ class SparkContext(object): self._jsc.stop() self._jsc = None - def parallelize(self, c, numSlices=None, serializer=None): - serializer = serializer or self.defaultSerializer - numSlices = numSlices or self.defaultParallelism - # Calling the Java parallelize() method with an ArrayList is too slow, - # because it sends O(n) Py4J commands. As an alternative, serialized - # objects are written to a file and loaded through textFile(). - tempFile = NamedTemporaryFile(delete=False) - tempFile.writelines(serializer.dumps(x) + '\n' for x in c) - tempFile.close() - atexit.register(lambda: os.unlink(tempFile.name)) - return self.textFile(tempFile.name, numSlices, serializer) - - def parallelizePairs(self, c, numSlices=None, keySerializer=None, - valSerializer=None): + def parallelize(self, c, numSlices=None): """ >>> sc = SparkContext("local", "test") - >>> rdd = sc.parallelizePairs([(1, 2), (3, 4)]) + >>> rdd = sc.parallelize([(1, 2), (3, 4)]) >>> rdd.collect() [(1, 2), (3, 4)] """ - keySerializer = keySerializer or self.defaultSerializer - valSerializer = valSerializer or self.defaultSerializer numSlices = numSlices or self.defaultParallelism + # Calling the Java parallelize() method with an ArrayList is too slow, + # because it sends O(n) Py4J commands. As an alternative, serialized + # objects are written to a file and loaded through textFile(). tempFile = NamedTemporaryFile(delete=False) - for (k, v) in c: - tempFile.write(keySerializer.dumps(k).rstrip('\r\n') + '\n') - tempFile.write(valSerializer.dumps(v).rstrip('\r\n') + '\n') + for x in c: + dumps(PickleSerializer.dumps(x), tempFile) tempFile.close() atexit.register(lambda: os.unlink(tempFile.name)) - jrdd = self.textFile(tempFile.name, numSlices)._pipePairs([], "echo") - return PairRDD(jrdd, self, keySerializer, valSerializer) + jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices) + return RDD(jrdd, self) - def textFile(self, name, numSlices=None, serializer=NopSerializer): + def textFile(self, name, numSlices=None): numSlices = numSlices or self.defaultParallelism jrdd = self._jsc.textFile(name, numSlices) - return RDD(jrdd, self, serializer) + return RDD(jrdd, self) diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py index 2df80aee85..bcb405ba72 100644 --- a/pyspark/pyspark/java_gateway.py +++ b/pyspark/pyspark/java_gateway.py @@ -16,5 +16,4 @@ def launch_gateway(): java_import(gateway.jvm, "spark.api.java.*") java_import(gateway.jvm, "spark.api.python.*") java_import(gateway.jvm, "scala.Tuple2") - java_import(gateway.jvm, "spark.api.python.PythonRDD.pythonDump") return gateway diff --git a/pyspark/pyspark/join.py b/pyspark/pyspark/join.py index c67520fce8..7036c47980 100644 --- a/pyspark/pyspark/join.py +++ b/pyspark/pyspark/join.py @@ -30,15 +30,12 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ -from pyspark.serializers import PairSerializer, OptionSerializer, \ - ArraySerializer -def _do_python_join(rdd, other, numSplits, dispatch, valSerializer): - vs = rdd.mapPairs(lambda (k, v): (k, (1, v))) - ws = other.mapPairs(lambda (k, v): (k, (2, v))) - return vs.union(ws).groupByKey(numSplits) \ - .flatMapValues(dispatch, valSerializer) +def _do_python_join(rdd, other, numSplits, dispatch): + vs = rdd.map(lambda (k, v): (k, (1, v))) + ws = other.map(lambda (k, v): (k, (2, v))) + return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch) def python_join(rdd, other, numSplits): @@ -50,8 +47,7 @@ def python_join(rdd, other, numSplits): elif n == 2: wbuf.append(v) return [(v, w) for v in vbuf for w in wbuf] - valSerializer = PairSerializer(rdd.valSerializer, other.valSerializer) - return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch) def python_right_outer_join(rdd, other, numSplits): @@ -65,9 +61,7 @@ def python_right_outer_join(rdd, other, numSplits): if not vbuf: vbuf.append(None) return [(v, w) for v in vbuf for w in wbuf] - valSerializer = PairSerializer(OptionSerializer(rdd.valSerializer), - other.valSerializer) - return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch) def python_left_outer_join(rdd, other, numSplits): @@ -81,17 +75,12 @@ def python_left_outer_join(rdd, other, numSplits): if not wbuf: wbuf.append(None) return [(v, w) for v in vbuf for w in wbuf] - valSerializer = PairSerializer(rdd.valSerializer, - OptionSerializer(other.valSerializer)) - return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch) def python_cogroup(rdd, other, numSplits): - resultValSerializer = PairSerializer( - ArraySerializer(rdd.valSerializer), - ArraySerializer(other.valSerializer)) - vs = rdd.mapPairs(lambda (k, v): (k, (1, v))) - ws = other.mapPairs(lambda (k, v): (k, (2, v))) + vs = rdd.map(lambda (k, v): (k, (1, v))) + ws = other.map(lambda (k, v): (k, (2, v))) def dispatch(seq): vbuf, wbuf = [], [] for (n, v) in seq: @@ -100,5 +89,4 @@ def python_cogroup(rdd, other, numSplits): elif n == 2: wbuf.append(v) return (vbuf, wbuf) - return vs.union(ws).groupByKey(numSplits) \ - .mapValues(dispatch, resultValSerializer) + return vs.union(ws).groupByKey(numSplits).mapValues(dispatch) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 5579c56de3..8eccddc0a2 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,31 +1,17 @@ from base64 import standard_b64encode as b64enc -from pyspark import cloudpickle -from itertools import chain -from pyspark.serializers import PairSerializer, NopSerializer, \ - OptionSerializer, ArraySerializer +from pyspark import cloudpickle +from pyspark.serializers import PickleSerializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup class RDD(object): - def __init__(self, jrdd, ctx, serializer=None): + def __init__(self, jrdd, ctx): self._jrdd = jrdd self.is_cached = False self.ctx = ctx - self.serializer = serializer or ctx.defaultSerializer - - def _builder(self, jrdd, ctx): - return RDD(jrdd, ctx, self.serializer) - - @property - def id(self): - return self._jrdd.id() - - @property - def splits(self): - return self._jrdd.splits() @classmethod def _get_pipe_command(cls, command, functions): @@ -41,55 +27,18 @@ class RDD(object): self._jrdd.cache() return self - def map(self, f, serializer=None, preservesPartitioning=False): - return MappedRDD(self, f, serializer, preservesPartitioning) - - def mapPairs(self, f, keySerializer=None, valSerializer=None, - preservesPartitioning=False): - return PairMappedRDD(self, f, keySerializer, valSerializer, - preservesPartitioning) + def map(self, f, preservesPartitioning=False): + return MappedRDD(self, f, preservesPartitioning) - def flatMap(self, f, serializer=None): + def flatMap(self, f): """ >>> rdd = sc.parallelize([2, 3, 4]) >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect()) [1, 1, 1, 2, 2, 3] - """ - serializer = serializer or self.ctx.defaultSerializer - dumps = serializer.dumps - loads = self.serializer.loads - def func(x): - pickled_elems = (dumps(y) for y in f(loads(x))) - return "\n".join(pickled_elems) or None - pipe_command = RDD._get_pipe_command("map", [func]) - class_manifest = self._jrdd.classManifest() - jrdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, - False, self.ctx.pythonExec, - class_manifest).asJavaRDD() - return RDD(jrdd, self.ctx, serializer) - - def flatMapPairs(self, f, keySerializer=None, valSerializer=None, - preservesPartitioning=False): - """ - >>> rdd = sc.parallelize([2, 3, 4]) - >>> sorted(rdd.flatMapPairs(lambda x: [(x, x), (x, x)]).collect()) + >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ - keySerializer = keySerializer or self.ctx.defaultSerializer - valSerializer = valSerializer or self.ctx.defaultSerializer - dumpk = keySerializer.dumps - dumpv = valSerializer.dumps - loads = self.serializer.loads - def func(x): - pairs = f(loads(x)) - pickled_pairs = ((dumpk(k), dumpv(v)) for (k, v) in pairs) - return "\n".join(chain.from_iterable(pickled_pairs)) or None - pipe_command = RDD._get_pipe_command("map", [func]) - class_manifest = self._jrdd.classManifest() - python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, - preservesPartitioning, self.ctx.pythonExec, class_manifest) - return PairRDD(python_rdd.asJavaPairRDD(), self.ctx, keySerializer, - valSerializer) + return MappedRDD(self, f, preservesPartitioning=False, command='flatmap') def filter(self, f): """ @@ -97,9 +46,8 @@ class RDD(object): >>> rdd.filter(lambda x: x % 2 == 0).collect() [2, 4] """ - loads = self.serializer.loads - def filter_func(x): return x if f(loads(x)) else None - return self._builder(self._pipe(filter_func), self.ctx) + def filter_func(x): return x if f(x) else None + return RDD(self._pipe(filter_func), self.ctx) def _pipe(self, functions, command="map"): class_manifest = self._jrdd.classManifest() @@ -108,32 +56,22 @@ class RDD(object): False, self.ctx.pythonExec, class_manifest) return python_rdd.asJavaRDD() - def _pipePairs(self, functions, command="mapPairs", - preservesPartitioning=False): - class_manifest = self._jrdd.classManifest() - pipe_command = RDD._get_pipe_command(command, functions) - python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, - preservesPartitioning, self.ctx.pythonExec, class_manifest) - return python_rdd.asJavaPairRDD() - def distinct(self): """ >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) [1, 2, 3] """ - if self.serializer.is_comparable: - return self._builder(self._jrdd.distinct(), self.ctx) - return self.mapPairs(lambda x: (x, "")) \ + return self.map(lambda x: (x, "")) \ .reduceByKey(lambda x, _: x) \ .map(lambda (x, _): x) def sample(self, withReplacement, fraction, seed): jrdd = self._jrdd.sample(withReplacement, fraction, seed) - return self._builder(jrdd, self.ctx) + return RDD(jrdd, self.ctx) def takeSample(self, withReplacement, num, seed): vals = self._jrdd.takeSample(withReplacement, num, seed) - return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + return [PickleSerializer.loads(x) for x in vals] def union(self, other): """ @@ -141,7 +79,7 @@ class RDD(object): >>> rdd.union(rdd).collect() [1, 1, 2, 3, 1, 1, 2, 3] """ - return self._builder(self._jrdd.union(other._jrdd), self.ctx) + return RDD(self._jrdd.union(other._jrdd), self.ctx) # TODO: sort @@ -155,16 +93,17 @@ class RDD(object): >>> sorted(rdd.cartesian(rdd).collect()) [(1, 1), (1, 2), (2, 1), (2, 2)] """ - return PairRDD(self._jrdd.cartesian(other._jrdd), self.ctx) + return RDD(self._jrdd.cartesian(other._jrdd), self.ctx) # numsplits def groupBy(self, f, numSplits=None): """ >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8]) - >>> sorted(rdd.groupBy(lambda x: x % 2).collect()) + >>> result = rdd.groupBy(lambda x: x % 2).collect() + >>> sorted([(x, sorted(y)) for (x, y) in result]) [(0, [2, 8]), (1, [1, 1, 3, 5])] """ - return self.mapPairs(lambda x: (f(x), x)).groupByKey(numSplits) + return self.map(lambda x: (f(x), x)).groupByKey(numSplits) # TODO: pipe @@ -178,25 +117,19 @@ class RDD(object): self.map(f).collect() # Force evaluation def collect(self): - vals = self._jrdd.collect() - return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect()) + return PickleSerializer.loads(bytes(pickle)) - def reduce(self, f, serializer=None): + def reduce(self, f): """ - >>> import operator - >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(operator.add) + >>> from operator import add + >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add) 15 + >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add) + 10 """ - serializer = serializer or self.ctx.defaultSerializer - loads = self.serializer.loads - dumps = serializer.dumps - def reduceFunction(x, acc): - if acc is None: - return loads(x) - else: - return f(loads(x), acc) - vals = self._pipe([reduceFunction, dumps], command="reduce").collect() - return reduce(f, (serializer.loads(x) for x in vals)) + vals = MappedRDD(self, f, command="reduce", preservesPartitioning=False).collect() + return reduce(f, vals) # TODO: fold @@ -216,36 +149,35 @@ class RDD(object): >>> sc.parallelize([2, 3, 4]).take(2) [2, 3] """ - vals = self._jrdd.take(num) - return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num)) + return PickleSerializer.loads(bytes(pickle)) def first(self): """ >>> sc.parallelize([2, 3, 4]).first() 2 """ - return self.serializer.loads(self.ctx.python_dump(self._jrdd.first())) + return PickleSerializer.loads(bytes(self.ctx.asPickle(self._jrdd.first()))) # TODO: saveAsTextFile # TODO: saveAsObjectFile + # Pair functions -class PairRDD(RDD): - - def __init__(self, jrdd, ctx, keySerializer=None, valSerializer=None): - RDD.__init__(self, jrdd, ctx) - self.keySerializer = keySerializer or ctx.defaultSerializer - self.valSerializer = valSerializer or ctx.defaultSerializer - self.serializer = \ - PairSerializer(self.keySerializer, self.valSerializer) - - def _builder(self, jrdd, ctx): - return PairRDD(jrdd, ctx, self.keySerializer, self.valSerializer) + def collectAsMap(self): + """ + >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap() + >>> m[1] + 2 + >>> m[3] + 4 + """ + return dict(self.collect()) def reduceByKey(self, func, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(x.reduceByKey(lambda a, b: a + b).collect()) [('a', 2), ('b', 1)] """ @@ -259,90 +191,67 @@ class PairRDD(RDD): def join(self, other, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) - >>> y = sc.parallelizePairs([("a", 2), ("a", 3)]) - >>> x.join(y).collect() + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2), ("a", 3)]) + >>> sorted(x.join(y).collect()) [('a', (1, 2)), ('a', (1, 3))] - - Check that we get a PairRDD-like object back: - >>> assert x.join(y).join """ - assert self.keySerializer.name == other.keySerializer.name - if self.keySerializer.is_comparable: - return PairRDD(self._jrdd.join(other._jrdd), - self.ctx, self.keySerializer, - PairSerializer(self.valSerializer, other.valSerializer)) - else: - return python_join(self, other, numSplits) + return python_join(self, other, numSplits) def leftOuterJoin(self, other, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) - >>> y = sc.parallelizePairs([("a", 2)]) + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) >>> sorted(x.leftOuterJoin(y).collect()) [('a', (1, 2)), ('b', (4, None))] """ - assert self.keySerializer.name == other.keySerializer.name - if self.keySerializer.is_comparable: - return PairRDD(self._jrdd.leftOuterJoin(other._jrdd), - self.ctx, self.keySerializer, - PairSerializer(self.valSerializer, - OptionSerializer(other.valSerializer))) - else: - return python_left_outer_join(self, other, numSplits) + return python_left_outer_join(self, other, numSplits) def rightOuterJoin(self, other, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) - >>> y = sc.parallelizePairs([("a", 2)]) + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) >>> sorted(y.rightOuterJoin(x).collect()) [('a', (2, 1)), ('b', (None, 4))] """ - assert self.keySerializer.name == other.keySerializer.name - if self.keySerializer.is_comparable: - return PairRDD(self._jrdd.rightOuterJoin(other._jrdd), - self.ctx, self.keySerializer, - PairSerializer(OptionSerializer(self.valSerializer), - other.valSerializer)) - else: - return python_right_outer_join(self, other, numSplits) + return python_right_outer_join(self, other, numSplits) + + # TODO: pipelining + # TODO: optimizations + def shuffle(self, numSplits): + if numSplits is None: + numSplits = self.ctx.defaultParallelism + pipe_command = RDD._get_pipe_command('shuffle_map_step', []) + class_manifest = self._jrdd.classManifest() + python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), + pipe_command, False, self.ctx.pythonExec, class_manifest) + partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) + jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner) + jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) + # TODO: extract second value. + return RDD(jrdd, self.ctx) + + def combineByKey(self, createCombiner, mergeValue, mergeCombiners, - numSplits=None, serializer=None): + numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> def f(x): return x >>> def add(a, b): return a + str(b) >>> sorted(x.combineByKey(str, add, add).collect()) [('a', '11'), ('b', '1')] """ - serializer = serializer or self.ctx.defaultSerializer if numSplits is None: numSplits = self.ctx.defaultParallelism - # Use hash() to create keys that are comparable in Java. - loadkv = self.serializer.loads - def pairify(kv): - # TODO: add method to deserialize only the key or value from - # a PairSerializer? - key = loadkv(kv)[0] - return (str(hash(key)), kv) - partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) - jrdd = self._pipePairs(pairify).partitionBy(partitioner) - pairified = PairRDD(jrdd, self.ctx, NopSerializer, self.serializer) - - loads = PairSerializer(NopSerializer, self.serializer).loads - dumpk = self.keySerializer.dumps - dumpc = serializer.dumps - - functions = [createCombiner, mergeValue, mergeCombiners, loads, dumpk, - dumpc] - jpairs = pairified._pipePairs(functions, "combine_by_key", - preservesPartitioning=True) - return PairRDD(jpairs, self.ctx, self.keySerializer, serializer) + shuffled = self.shuffle(numSplits) + functions = [createCombiner, mergeValue, mergeCombiners] + jpairs = shuffled._pipe(functions, "combine_by_key") + return RDD(jpairs, self.ctx) def groupByKey(self, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(x.groupByKey().collect()) [('a', [1, 1]), ('b', [1])] """ @@ -360,29 +269,15 @@ class PairRDD(RDD): return self.combineByKey(createCombiner, mergeValue, mergeCombiners, numSplits) - def collectAsMap(self): - """ - >>> m = sc.parallelizePairs([(1, 2), (3, 4)]).collectAsMap() - >>> m[1] - 2 - >>> m[3] - 4 - """ - m = self._jrdd.collectAsMap() - def loads(x): - (k, v) = x - return (self.keySerializer.loads(k), self.valSerializer.loads(v)) - return dict(loads(x) for x in m.items()) - - def flatMapValues(self, f, valSerializer=None): + def flatMapValues(self, f): flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) - return self.flatMapPairs(flat_map_fn, self.keySerializer, - valSerializer, True) + return self.flatMap(flat_map_fn) - def mapValues(self, f, valSerializer=None): + def mapValues(self, f): map_values_fn = lambda (k, v): (k, f(v)) - return self.mapPairs(map_values_fn, self.keySerializer, valSerializer, - True) + return self.map(map_values_fn, preservesPartitioning=True) + + # TODO: implement shuffle. # TODO: support varargs cogroup of several RDDs. def groupWith(self, other): @@ -390,20 +285,12 @@ class PairRDD(RDD): def cogroup(self, other, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) - >>> y = sc.parallelizePairs([("a", 2)]) + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) >>> x.cogroup(y).collect() [('a', ([1], [2])), ('b', ([4], []))] """ - assert self.keySerializer.name == other.keySerializer.name - resultValSerializer = PairSerializer( - ArraySerializer(self.valSerializer), - ArraySerializer(other.valSerializer)) - if self.keySerializer.is_comparable: - return PairRDD(self._jrdd.cogroup(other._jrdd), - self.ctx, self.keySerializer, resultValSerializer) - else: - return python_cogroup(self, other, numSplits) + return python_cogroup(self, other, numSplits) # TODO: `lookup` is disabled because we can't make direct comparisons based # on the key; we need to compare the hash of the key to the hash of the @@ -413,44 +300,84 @@ class PairRDD(RDD): # TODO: file saving -class MappedRDDBase(object): - def __init__(self, prev, func, serializer, preservesPartitioning=False): - if isinstance(prev, MappedRDDBase) and not prev.is_cached: +class MappedRDD(RDD): + """ + Pipelined maps: + >>> rdd = sc.parallelize([1, 2, 3, 4]) + >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + + Pipelined reduces: + >>> from operator import add + >>> rdd.map(lambda x: 2 * x).reduce(add) + 20 + >>> rdd.flatMap(lambda x: [x, x]).reduce(add) + 20 + """ + def __init__(self, prev, func, preservesPartitioning=False, command='map'): + if isinstance(prev, MappedRDD) and not prev.is_cached: prev_func = prev.func - self.func = lambda x: func(prev_func(x)) + if command == 'reduce': + if prev.command == 'flatmap': + def flatmap_reduce_func(x, acc): + values = prev_func(x) + if values is None: + return acc + if not acc: + if len(values) == 1: + return values[0] + else: + return reduce(func, values[1:], values[0]) + else: + return reduce(func, values, acc) + self.func = flatmap_reduce_func + else: + def reduce_func(x, acc): + val = prev_func(x) + if not val: + return acc + if acc is None: + return val + else: + return func(val, acc) + self.func = reduce_func + else: + if prev.command == 'flatmap': + command = 'flatmap' + self.func = lambda x: (func(y) for y in prev_func(x)) + else: + self.func = lambda x: func(prev_func(x)) + self.preservesPartitioning = \ prev.preservesPartitioning and preservesPartitioning self._prev_jrdd = prev._prev_jrdd - self._prev_serializer = prev._prev_serializer + self.is_pipelined = True else: - self.func = func + if command == 'reduce': + def reduce_func(val, acc): + if acc is None: + return val + else: + return func(val, acc) + self.func = reduce_func + else: + self.func = func self.preservesPartitioning = preservesPartitioning self._prev_jrdd = prev._jrdd - self._prev_serializer = prev.serializer - self.serializer = serializer or prev.ctx.defaultSerializer + self.is_pipelined = False self.is_cached = False self.ctx = prev.ctx self.prev = prev self._jrdd_val = None - - -class MappedRDD(MappedRDDBase, RDD): - """ - >>> rdd = sc.parallelize([1, 2, 3, 4]) - >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() - [4, 8, 12, 16] - >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect() - [4, 8, 12, 16] - """ + self.command = command @property def _jrdd(self): if not self._jrdd_val: - udf = self.func - loads = self._prev_serializer.loads - dumps = self.serializer.dumps - func = lambda x: dumps(udf(loads(x))) - pipe_command = RDD._get_pipe_command("map", [func]) + funcs = [self.func] + pipe_command = RDD._get_pipe_command(self.command, funcs) class_manifest = self._prev_jrdd.classManifest() python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, self.preservesPartitioning, self.ctx.pythonExec, @@ -459,56 +386,11 @@ class MappedRDD(MappedRDDBase, RDD): return self._jrdd_val -class PairMappedRDD(MappedRDDBase, PairRDD): - """ - >>> rdd = sc.parallelize([1, 2, 3, 4]) - >>> rdd.mapPairs(lambda x: (x, x)) \\ - ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\ - ... .collect() - [(2, 2), (4, 4), (6, 6), (8, 8)] - >>> rdd.mapPairs(lambda x: (x, x)) \\ - ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\ - ... .map(lambda (x, _): x).collect() - [2, 4, 6, 8] - """ - - def __init__(self, prev, func, keySerializer=None, valSerializer=None, - preservesPartitioning=False): - self.keySerializer = keySerializer or prev.ctx.defaultSerializer - self.valSerializer = valSerializer or prev.ctx.defaultSerializer - serializer = PairSerializer(self.keySerializer, self.valSerializer) - MappedRDDBase.__init__(self, prev, func, serializer, - preservesPartitioning) - - @property - def _jrdd(self): - if not self._jrdd_val: - udf = self.func - loads = self._prev_serializer.loads - dumpk = self.keySerializer.dumps - dumpv = self.valSerializer.dumps - def func(x): - (k, v) = udf(loads(x)) - return (dumpk(k), dumpv(v)) - pipe_command = RDD._get_pipe_command("mapPairs", [func]) - class_manifest = self._prev_jrdd.classManifest() - self._jrdd_val = self.ctx.jvm.PythonPairRDD(self._prev_jrdd.rdd(), - pipe_command, self.preservesPartitioning, self.ctx.pythonExec, - class_manifest).asJavaPairRDD() - return self._jrdd_val - - def _test(): import doctest from pyspark.context import SparkContext - from pyspark.serializers import PickleSerializer, JSONSerializer globs = globals().copy() - globs['sc'] = SparkContext('local', 'PythonTest', - defaultSerializer=JSONSerializer) - doctest.testmod(globs=globs) - globs['sc'].stop() - globs['sc'] = SparkContext('local', 'PythonTest', - defaultSerializer=PickleSerializer) + globs['sc'] = SparkContext('local', 'PythonTest') doctest.testmod(globs=globs) globs['sc'].stop() diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index b113f5656b..7b3e6966e1 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -2,228 +2,35 @@ Data serialization methods. The Spark Python API is built on top of the Spark Java API. RDDs created in -Python are stored in Java as RDDs of Strings. Python objects are automatically -serialized/deserialized, so this representation is transparent to the end-user. - ------------------- -Serializer objects ------------------- - -`Serializer` objects are used to customize how an RDD's values are serialized. - -Each `Serializer` is a named tuple with four fields: - - - A `dumps` function, for serializing a Python object to a string. - - - A `loads` function, for deserializing a Python object from a string. - - - An `is_comparable` field, True if equal Python objects are serialized to - equal strings, and False otherwise. - - - A `name` field, used to identify the Serializer. Serializers are - compared for equality by comparing their names. - -The serializer's output should be base64-encoded. - ------------------------------------------------------------------- -`is_comparable`: comparing serialized representations for equality ------------------------------------------------------------------- - -If `is_comparable` is False, the serializer's representations of equal objects -are not required to be equal: - ->>> import pickle ->>> a = {1: 0, 9: 0} ->>> b = {9: 0, 1: 0} ->>> a == b -True ->>> pickle.dumps(a) == pickle.dumps(b) -False - -RDDs with comparable serializers can use native Java implementations of -operations like join() and distinct(), which may lead to better performance by -eliminating deserialization and Python comparisons. - -The default JSONSerializer produces comparable representations of common Python -data structures. - --------------------------------------- -Examples of serialized representations --------------------------------------- - -The RDD transformations that use Python UDFs are implemented in terms of -a modified `PipedRDD.pipe()` function. For each record `x` in the RDD, the -`pipe()` function pipes `x.toString()` to a Python worker process, which -deserializes the string into a Python object, executes user-defined functions, -and outputs serialized Python objects. - -The regular `toString()` method returns an ambiguous representation, due to the -way that Scala `Option` instances are printed: - ->>> from context import SparkContext ->>> sc = SparkContext("local", "SerializerDocs") ->>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) ->>> y = sc.parallelizePairs([("a", 2)]) - ->>> print y.rightOuterJoin(x)._jrdd.first().toString() -(ImEi,(Some(Mg==),MQ==)) - -In Java, preprocessing is performed to handle Option instances, so the Python -process receives unambiguous input: - ->>> print sc.python_dump(y.rightOuterJoin(x)._jrdd.first()) -(ImEi,(Mg==,MQ==)) - -The base64-encoding eliminates the need to escape newlines, parentheses and -other special characters. - ----------------------- -Serializer composition ----------------------- - -In order to handle nested structures, which could contain object serialized -with different serializers, the RDD module composes serializers. For example, -the serializers in the previous example are: - ->>> print x.serializer.name -PairSerializer - ->>> print y.serializer.name -PairSerializer - ->>> print y.rightOuterJoin(x).serializer.name -PairSerializer, JSONSerializer>> +Python are stored in Java as RDD[Array[Byte]]. Python objects are +automatically serialized/deserialized, so this representation is transparent to +the end-user. """ -from base64 import standard_b64encode, standard_b64decode from collections import namedtuple import cPickle -import simplejson - - -Serializer = namedtuple("Serializer", - ["dumps","loads", "is_comparable", "name"]) - - -NopSerializer = Serializer(str, str, True, "NopSerializer") +import struct -JSONSerializer = Serializer( - lambda obj: standard_b64encode(simplejson.dumps(obj, sort_keys=True, - separators=(',', ':'))), - lambda s: simplejson.loads(standard_b64decode(s)), - True, - "JSONSerializer" -) +Serializer = namedtuple("Serializer", ["dumps","loads"]) PickleSerializer = Serializer( - lambda obj: standard_b64encode(cPickle.dumps(obj)), - lambda s: cPickle.loads(standard_b64decode(s)), - False, - "PickleSerializer" -) - - -def OptionSerializer(serializer): - """ - >>> ser = OptionSerializer(NopSerializer) - >>> ser.loads(ser.dumps("Hello, World!")) - 'Hello, World!' - >>> ser.loads(ser.dumps(None)) is None - True - """ - none_placeholder = '*' - - def dumps(x): - if x is None: - return none_placeholder - else: - return serializer.dumps(x) - - def loads(x): - if x == none_placeholder: - return None - else: - return serializer.loads(x) - - name = "OptionSerializer<%s>" % serializer.name - return Serializer(dumps, loads, serializer.is_comparable, name) - - -def PairSerializer(keySerializer, valSerializer): - """ - Returns a Serializer for a (key, value) pair. - - >>> ser = PairSerializer(JSONSerializer, JSONSerializer) - >>> ser.loads(ser.dumps((1, 2))) - (1, 2) - - >>> ser = PairSerializer(JSONSerializer, ser) - >>> ser.loads(ser.dumps((1, (2, 3)))) - (1, (2, 3)) - """ - def loads(kv): - try: - (key, val) = kv[1:-1].split(',', 1) - key = keySerializer.loads(key) - val = valSerializer.loads(val) - return (key, val) - except: - print "Error in deserializing pair from '%s'" % str(kv) - raise - - def dumps(kv): - (key, val) = kv - return"(%s,%s)" % (keySerializer.dumps(key), valSerializer.dumps(val)) - is_comparable = \ - keySerializer.is_comparable and valSerializer.is_comparable - name = "PairSerializer<%s, %s>" % (keySerializer.name, valSerializer.name) - return Serializer(dumps, loads, is_comparable, name) - - -def ArraySerializer(serializer): - """ - >>> ser = ArraySerializer(JSONSerializer) - >>> ser.loads(ser.dumps([1, 2, 3, 4])) - [1, 2, 3, 4] - >>> ser = ArraySerializer(PairSerializer(JSONSerializer, PickleSerializer)) - >>> ser.loads(ser.dumps([('a', 1), ('b', 2)])) - [('a', 1), ('b', 2)] - >>> ser.loads(ser.dumps([('a', 1)])) - [('a', 1)] - >>> ser.loads(ser.dumps([])) - [] - """ - def dumps(arr): - if arr == []: - return '[]' - else: - return '[' + '|'.join(serializer.dumps(x) for x in arr) + ']' - - def loads(s): - if s == '[]': - return [] - items = s[1:-1] - if '|' in items: - items = items.split('|') - else: - items = [items] - return [serializer.loads(x) for x in items] - - name = "ArraySerializer<%s>" % serializer.name - return Serializer(dumps, loads, serializer.is_comparable, name) - - -# TODO: IntegerSerializer - - -# TODO: DoubleSerializer + lambda obj: cPickle.dumps(obj, -1), + cPickle.loads) -def _test(): - import doctest - doctest.testmod() +def dumps(obj, stream): + # TODO: determining the length of non-byte objects. + stream.write(struct.pack("!i", len(obj))) + stream.write(obj) -if __name__ == "__main__": - _test() +def loads(stream): + length = stream.read(4) + if length == "": + raise EOFError + length = struct.unpack("!i", length)[0] + obj = stream.read(length) + if obj == "": + raise EOFError + return obj diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 4c4b02fce4..21ff84fb17 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -6,9 +6,9 @@ from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. from pyspark.cloudpickle import CloudPickler +from pyspark.serializers import dumps, loads, PickleSerializer import cPickle - # Redirect stdout to stderr so that users must return values from functions. old_stdout = sys.stdout sys.stdout = sys.stderr @@ -19,58 +19,64 @@ def load_function(): def output(x): - for line in x.split("\n"): - old_stdout.write(line.rstrip("\r\n") + "\n") + dumps(x, old_stdout) def read_input(): - for line in sys.stdin: - yield line.rstrip("\r\n") - + try: + while True: + yield loads(sys.stdin) + except EOFError: + return def do_combine_by_key(): create_combiner = load_function() merge_value = load_function() merge_combiners = load_function() # TODO: not used. - depickler = load_function() - key_pickler = load_function() - combiner_pickler = load_function() combiners = {} - for line in read_input(): - # Discard the hashcode added in the Python combineByKey() method. - (key, value) = depickler(line)[1] + for obj in read_input(): + (key, value) = PickleSerializer.loads(obj) if key not in combiners: combiners[key] = create_combiner(value) else: combiners[key] = merge_value(combiners[key], value) for (key, combiner) in combiners.iteritems(): - output(key_pickler(key)) - output(combiner_pickler(combiner)) + output(PickleSerializer.dumps((key, combiner))) -def do_map(map_pairs=False): +def do_map(flat=False): f = load_function() - for line in read_input(): + for obj in read_input(): try: - out = f(line) + #from pickletools import dis + #print repr(obj) + #print dis(obj) + out = f(PickleSerializer.loads(obj)) if out is not None: - if map_pairs: + if flat: for x in out: - output(x) + output(PickleSerializer.dumps(x)) else: - output(out) + output(PickleSerializer.dumps(out)) except: - sys.stderr.write("Error processing line '%s'\n" % line) + sys.stderr.write("Error processing obj %s\n" % repr(obj)) raise +def do_shuffle_map_step(): + for obj in read_input(): + key = PickleSerializer.loads(obj)[1] + output(str(hash(key))) + output(obj) + + def do_reduce(): f = load_function() - dumps = load_function() acc = None - for line in read_input(): - acc = f(line, acc) - output(dumps(acc)) + for obj in read_input(): + acc = f(PickleSerializer.loads(obj), acc) + if acc is not None: + output(PickleSerializer.dumps(acc)) def do_echo(): @@ -80,13 +86,15 @@ def do_echo(): def main(): command = sys.stdin.readline().strip() if command == "map": - do_map(map_pairs=False) - elif command == "mapPairs": - do_map(map_pairs=True) + do_map(flat=False) + elif command == "flatmap": + do_map(flat=True) elif command == "combine_by_key": do_combine_by_key() elif command == "reduce": do_reduce() + elif command == "shuffle_map_step": + do_shuffle_map_step() elif command == "echo": do_echo() else: -- cgit v1.2.3 From 607b53abfca049e7d9139e2d29893a3bb252de19 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 22 Aug 2012 00:43:55 -0700 Subject: Use numpy in Python k-means example. --- .../main/scala/spark/api/python/PythonRDD.scala | 8 +++++++- pyspark/pyspark/examples/kmeans.py | 23 ++++++++-------------- pyspark/pyspark/rdd.py | 9 +++------ pyspark/pyspark/worker.py | 8 +++----- 4 files changed, 21 insertions(+), 27 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index b9a0168d18..93847e2f14 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -101,7 +101,13 @@ trait PythonRDDBase { stream.readFully(obj) obj } catch { - case eof: EOFException => { new Array[Byte](0) } + case eof: EOFException => { + val exitStatus = proc.waitFor() + if (exitStatus != 0) { + throw new Exception("Subprocess exited with status " + exitStatus) + } + new Array[Byte](0) + } case e => throw e } } diff --git a/pyspark/pyspark/examples/kmeans.py b/pyspark/pyspark/examples/kmeans.py index 0761d6e395..9cc366f03c 100644 --- a/pyspark/pyspark/examples/kmeans.py +++ b/pyspark/pyspark/examples/kmeans.py @@ -1,25 +1,18 @@ import sys from pyspark.context import SparkContext +from numpy import array, sum as np_sum def parseVector(line): - return [float(x) for x in line.split(' ')] - - -def addVec(x, y): - return [a + b for (a, b) in zip(x, y)] - - -def squaredDist(x, y): - return sum((a - b) ** 2 for (a, b) in zip(x, y)) + return array([float(x) for x in line.split(' ')]) def closestPoint(p, centers): bestIndex = 0 closest = float("+inf") for i in range(len(centers)): - tempDist = squaredDist(p, centers[i]) + tempDist = np_sum((p - centers[i]) ** 2) if tempDist < closest: closest = tempDist bestIndex = i @@ -41,14 +34,14 @@ if __name__ == "__main__": tempDist = 1.0 while tempDist > convergeDist: - closest = data.mapPairs( + closest = data.map( lambda p : (closestPoint(p, kPoints), (p, 1))) pointStats = closest.reduceByKey( - lambda (x1, y1), (x2, y2): (addVec(x1, x2), y1 + y2)) - newPoints = pointStats.mapPairs( - lambda (x, (y, z)): (x, [a / z for a in y])).collect() + lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2)) + newPoints = pointStats.map( + lambda (x, (y, z)): (x, y / z)).collect() - tempDist = sum(squaredDist(kPoints[x], y) for (x, y) in newPoints) + tempDist = sum(np_sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) for (x, y) in newPoints: kPoints[x] = y diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 8eccddc0a2..ff9c483032 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -71,7 +71,7 @@ class RDD(object): def takeSample(self, withReplacement, num, seed): vals = self._jrdd.takeSample(withReplacement, num, seed) - return [PickleSerializer.loads(x) for x in vals] + return [PickleSerializer.loads(bytes(x)) for x in vals] def union(self, other): """ @@ -218,17 +218,16 @@ class RDD(object): # TODO: pipelining # TODO: optimizations - def shuffle(self, numSplits): + def shuffle(self, numSplits, hashFunc=hash): if numSplits is None: numSplits = self.ctx.defaultParallelism - pipe_command = RDD._get_pipe_command('shuffle_map_step', []) + pipe_command = RDD._get_pipe_command('shuffle_map_step', [hashFunc]) class_manifest = self._jrdd.classManifest() python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, False, self.ctx.pythonExec, class_manifest) partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) - # TODO: extract second value. return RDD(jrdd, self.ctx) @@ -277,8 +276,6 @@ class RDD(object): map_values_fn = lambda (k, v): (k, f(v)) return self.map(map_values_fn, preservesPartitioning=True) - # TODO: implement shuffle. - # TODO: support varargs cogroup of several RDDs. def groupWith(self, other): return self.cogroup(other) diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 21ff84fb17..b13ed5699a 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -48,9 +48,6 @@ def do_map(flat=False): f = load_function() for obj in read_input(): try: - #from pickletools import dis - #print repr(obj) - #print dis(obj) out = f(PickleSerializer.loads(obj)) if out is not None: if flat: @@ -64,9 +61,10 @@ def do_map(flat=False): def do_shuffle_map_step(): + hashFunc = load_function() for obj in read_input(): - key = PickleSerializer.loads(obj)[1] - output(str(hash(key))) + key = PickleSerializer.loads(obj)[0] + output(str(hashFunc(key))) output(obj) -- cgit v1.2.3 From 741899b21e4e6439459fcf4966076661c851ed07 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 16:26:06 -0700 Subject: Fix sendMessageReliablySync --- core/src/main/scala/spark/network/ConnectionManager.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 1a22d06cc8..66b822117f 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -14,7 +14,8 @@ import scala.collection.mutable.SynchronizedQueue import scala.collection.mutable.Queue import scala.collection.mutable.ArrayBuffer -import akka.dispatch.{Promise, ExecutionContext, Future} +import akka.dispatch.{Await, Promise, ExecutionContext, Future} +import akka.util.Duration case class ConnectionManagerId(host: String, port: Int) { def toSocketAddress() = new InetSocketAddress(host, port) @@ -325,7 +326,7 @@ class ConnectionManager(port: Int) extends Logging { } def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = { - sendMessageReliably(connectionManagerId, message)() + Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf) } def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { -- cgit v1.2.3 From 06ef7c3d1bf8446d4d6ef8f3a055dd1e6d32ca3a Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 16:29:20 -0700 Subject: Less debug info --- core/src/main/scala/spark/network/ConnectionManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 66b822117f..bd0980029a 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -248,7 +248,7 @@ class ConnectionManager(port: Int) extends Logging { } private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) { - logInfo("Handling [" + message + "] from [" + connectionManagerId + "]") + logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") message match { case bufferMessage: BufferMessage => { if (bufferMessage.hasAckId) { -- cgit v1.2.3 From 29e83f39e90b4d3cbeeb40d5ec0c19bd003c1840 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 18:16:25 -0700 Subject: Fix replication with MEMORY_ONLY_DESER_2 --- core/src/main/scala/spark/storage/BlockManager.scala | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index ff9914ae25..45f99717bc 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -364,6 +364,12 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val startTimeMs = System.currentTimeMillis var bytes: ByteBuffer = null + + // If we need to replicate the data, we'll want access to the values, but because our + // put will read the whole iterator, there will be no values left. For the case where + // the put serializes data, we'll remember the bytes, above; but for the case where + // it doesn't, such as MEMORY_ONLY_DESER, let's rely on the put returning an Iterator. + var valuesAfterPut: Iterator[Any] = null locker.getLock(blockId).synchronized { logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) @@ -391,7 +397,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // If only save to memory memoryStore.putValues(blockId, values, level) match { case Right(newBytes) => bytes = newBytes - case _ => + case Left(newIterator) => valuesAfterPut = newIterator } } else { // If only save to disk @@ -408,8 +414,13 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // Replicate block if required if (level.replication > 1) { + // Serialize the block if not already done if (bytes == null) { - bytes = dataSerialize(values) // serialize the block if not already done + if (valuesAfterPut == null) { + throw new SparkException( + "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") + } + bytes = dataSerialize(valuesAfterPut) } replicate(blockId, bytes, level) } -- cgit v1.2.3 From 26dfd20c9a5139bd682a9902267b9d54a11ae20f Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 18:56:56 -0700 Subject: Detect disconnected slaves in StandaloneScheduler --- .../cluster/StandaloneSchedulerBackend.scala | 38 ++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 013671c1c8..83e7c6e036 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -2,13 +2,14 @@ package spark.scheduler.cluster import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import akka.actor.{Props, Actor, ActorRef, ActorSystem} +import akka.actor._ import akka.util.duration._ import akka.pattern.ask import spark.{SparkException, Logging, TaskState} import akka.dispatch.Await import java.util.concurrent.atomic.AtomicInteger +import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} /** * A standalone scheduler backend, which waits for standalone executors to connect to it through @@ -23,8 +24,16 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor { val slaveActor = new HashMap[String, ActorRef] + val slaveAddress = new HashMap[String, Address] val slaveHost = new HashMap[String, String] val freeCores = new HashMap[String, Int] + val actorToSlaveId = new HashMap[ActorRef, String] + val addressToSlaveId = new HashMap[Address, String] + + override def preStart() { + // Listen for remote client disconnection events, since they don't go through Akka's watch() + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + } def receive = { case RegisterSlave(slaveId, host, cores) => @@ -33,9 +42,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } else { logInfo("Registered slave: " + sender + " with ID " + slaveId) sender ! RegisteredSlave(sparkProperties) + context.watch(sender) slaveActor(slaveId) = sender slaveHost(slaveId) = host freeCores(slaveId) = cores + slaveAddress(slaveId) = sender.path.address + actorToSlaveId(sender) = slaveId + addressToSlaveId(sender.path.address) = slaveId totalCoreCount.addAndGet(cores) makeOffers() } @@ -54,7 +67,14 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor sender ! true context.stop(self) - // TODO: Deal with nodes disconnecting too! (Including decreasing totalCoreCount) + case Terminated(actor) => + actorToSlaveId.get(actor).foreach(removeSlave) + + case RemoteClientDisconnected(transport, address) => + addressToSlaveId.get(address).foreach(removeSlave) + + case RemoteClientShutdown(transport, address) => + addressToSlaveId.get(address).foreach(removeSlave) } // Make fake resource offers on all slaves @@ -76,6 +96,20 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor slaveActor(task.slaveId) ! LaunchTask(task) } } + + // Remove a disconnected slave from the cluster + def removeSlave(slaveId: String) { + logInfo("Slave " + slaveId + " disconnected, so removing it") + val numCores = freeCores(slaveId) + actorToSlaveId -= slaveActor(slaveId) + addressToSlaveId -= slaveAddress(slaveId) + slaveActor -= slaveId + slaveHost -= slaveId + freeCores -= slaveId + slaveHost -= slaveId + totalCoreCount.addAndGet(-numCores) + scheduler.slaveLost(slaveId) + } } var masterActor: ActorRef = null -- cgit v1.2.3 From 3c9c44a8d36c0f2dff40a50f1f1e3bc3dac7be7e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 19:37:43 -0700 Subject: More helpful log messages --- core/src/main/scala/spark/MapOutputTracker.scala | 3 ++- core/src/main/scala/spark/network/Connection.scala | 4 ++-- core/src/main/scala/spark/network/ConnectionManager.scala | 2 +- core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 1 + 4 files changed, 6 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 0c97cd44a1..e249430905 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -116,7 +116,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = { val locs = bmAddresses.get(shuffleId) if (locs == null) { - logInfo("Don't have map outputs for shuffe " + shuffleId + ", fetching them") + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") fetching.synchronized { if (fetching.contains(shuffleId)) { // Someone else is fetching it; wait for them to be done @@ -158,6 +158,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg def incrementGeneration() { generationLock.synchronized { generation += 1 + logInfo("Increasing generation to " + generation) } } diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 451faee66e..da8aff9dd5 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -111,7 +111,7 @@ extends Connection(SocketChannel.open, selector_) { messages.synchronized{ /*messages += message*/ messages.enqueue(message) - logInfo("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") + logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") } } @@ -136,7 +136,7 @@ extends Connection(SocketChannel.open, selector_) { return chunk } /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ - logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) + logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) } } None diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index bd0980029a..0e764fff81 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -306,7 +306,7 @@ class ConnectionManager(port: Int) extends Logging { } val connection = connectionsById.getOrElse(connectionManagerId, startNewConnection()) message.senderAddress = id.toSocketAddress() - logInfo("Sending [" + message + "] to [" + connectionManagerId + "]") + logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") /*connection.send(message)*/ sendMessageRequests.synchronized { sendMessageRequests += ((message, connection)) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index be24316e80..5412e8d8c0 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -88,6 +88,7 @@ class TaskSetManager( // Figure out the current map output tracker generation and set it on all tasks val generation = sched.mapOutputTracker.getGeneration + logInfo("Generation for " + taskSet.id + ": " + generation) for (t <- tasks) { t.generation = generation } -- cgit v1.2.3 From 117e3f8c8602c1303fa0e31840d85d1a7a6e3d9d Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 19:52:56 -0700 Subject: Fix a bug that was causing FetchFailedException not to be thrown --- core/src/main/scala/spark/BlockStoreShuffleFetcher.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index 3431ad2258..45a14c8290 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -48,8 +48,9 @@ class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { } } } catch { + // TODO: this is really ugly -- let's find a better way of throwing a FetchFailedException case be: BlockException => { - val regex = "shuffledid_([0-9]*)_([0-9]*)_([0-9]]*)".r + val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]]*)".r be.blockId match { case regex(sId, mId, rId) => { val address = addresses(mId.toInt) -- cgit v1.2.3 From 69c2ab04083972e4ecf1393ffd0cb0acb56b4f7d Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 20:00:58 -0700 Subject: logging --- core/src/main/scala/spark/executor/Executor.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 9e335c25f7..dba209ac27 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -63,6 +63,7 @@ class Executor extends Logging { Thread.currentThread.setContextClassLoader(classLoader) Accumulators.clear() val task = ser.deserialize[Task[Any]](serializedTask, classLoader) + logInfo("Its generation is " + task.generation) env.mapOutputTracker.updateGeneration(task.generation) val value = task.run(taskId.toInt) val accumUpdates = Accumulators.values -- cgit v1.2.3 From b914cd0dfa21b615c29d2ce935f623f209afa8f4 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 20:07:59 -0700 Subject: Serialize generation correctly in ShuffleMapTask --- core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 2 ++ 1 file changed, 2 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index f78e0e5fb2..73479bff01 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -90,6 +90,7 @@ class ShuffleMapTask( out.writeInt(bytes.length) out.write(bytes) out.writeInt(partition) + out.writeLong(generation) out.writeObject(split) } @@ -102,6 +103,7 @@ class ShuffleMapTask( rdd = rdd_ dep = dep_ partition = in.readInt() + generation = in.readLong() split = in.readObject().asInstanceOf[Split] } -- cgit v1.2.3 From f79a1e4d2a8643157136de69b8d7de84f0034712 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 25 Aug 2012 13:59:01 -0700 Subject: Add broadcast variables to Python API. --- .../main/scala/spark/api/python/PythonRDD.scala | 43 ++++++++++++-------- pyspark/pyspark/broadcast.py | 46 ++++++++++++++++++++++ pyspark/pyspark/context.py | 17 ++++++-- pyspark/pyspark/rdd.py | 27 ++++++++----- pyspark/pyspark/worker.py | 6 +++ 5 files changed, 110 insertions(+), 29 deletions(-) create mode 100644 pyspark/pyspark/broadcast.py (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 93847e2f14..5163812df4 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -7,14 +7,13 @@ import scala.collection.JavaConversions._ import scala.io.Source import spark._ import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} -import scala.{collection, Some} -import collection.parallel.mutable +import broadcast.Broadcast import scala.collection -import scala.Some trait PythonRDDBase { def compute[T](split: Split, envVars: Map[String, String], - command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[Array[Byte]] = { + command: Seq[String], parent: RDD[T], pythonExec: String, + broadcastVars: java.util.List[Broadcast[Array[Byte]]]): Iterator[Array[Byte]] = { val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) @@ -42,11 +41,18 @@ trait PythonRDDBase { override def run() { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) + val dOut = new DataOutputStream(proc.getOutputStream) + out.println(broadcastVars.length) + for (broadcast <- broadcastVars) { + out.print(broadcast.uuid.toString) + dOut.writeInt(broadcast.value.length) + dOut.write(broadcast.value) + dOut.flush() + } for (elem <- command) { out.println(elem) } out.flush() - val dOut = new DataOutputStream(proc.getOutputStream) for (elem <- parent.iterator(split)) { if (elem.isInstanceOf[Array[Byte]]) { val arr = elem.asInstanceOf[Array[Byte]] @@ -121,16 +127,17 @@ trait PythonRDDBase { class PythonRDD[T: ClassManifest]( parent: RDD[T], command: Seq[String], envVars: Map[String, String], - preservePartitoning: Boolean, pythonExec: String) + preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) extends RDD[Array[Byte]](parent.context) with PythonRDDBase { - def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = - this(parent, command, Map(), preservePartitoning, pythonExec) + def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, + pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) = - this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec) + def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars) override def splits = parent.splits @@ -139,23 +146,25 @@ class PythonRDD[T: ClassManifest]( override val partitioner = if (preservePartitoning) parent.partitioner else None override def compute(split: Split): Iterator[Array[Byte]] = - compute(split, envVars, command, parent, pythonExec) + compute(split, envVars, command, parent, pythonExec, broadcastVars) val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } class PythonPairRDD[T: ClassManifest] ( parent: RDD[T], command: Seq[String], envVars: Map[String, String], - preservePartitoning: Boolean, pythonExec: String) + preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase { - def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = - this(parent, command, Map(), preservePartitoning, pythonExec) + def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, + pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) = - this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec) + def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, + broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars) override def splits = parent.splits @@ -164,7 +173,7 @@ class PythonPairRDD[T: ClassManifest] ( override val partitioner = if (preservePartitoning) parent.partitioner else None override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = { - compute(split, envVars, command, parent, pythonExec).grouped(2).map { + compute(split, envVars, command, parent, pythonExec, broadcastVars).grouped(2).map { case Seq(a, b) => (a, b) case x => throw new Exception("PythonPairRDD: unexpected value: " + x) } diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py new file mode 100644 index 0000000000..1ea17d59af --- /dev/null +++ b/pyspark/pyspark/broadcast.py @@ -0,0 +1,46 @@ +""" +>>> from pyspark.context import SparkContext +>>> sc = SparkContext('local', 'test') +>>> b = sc.broadcast([1, 2, 3, 4, 5]) +>>> b.value +[1, 2, 3, 4, 5] + +>>> from pyspark.broadcast import _broadcastRegistry +>>> _broadcastRegistry[b.uuid] = b +>>> from cPickle import dumps, loads +>>> loads(dumps(b)).value +[1, 2, 3, 4, 5] + +>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() +[1, 2, 3, 4, 5, 1, 2, 3, 4, 5] +""" +# Holds broadcasted data received from Java, keyed by UUID. +_broadcastRegistry = {} + + +def _from_uuid(uuid): + from pyspark.broadcast import _broadcastRegistry + if uuid not in _broadcastRegistry: + raise Exception("Broadcast variable '%s' not loaded!" % uuid) + return _broadcastRegistry[uuid] + + +class Broadcast(object): + def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None): + self.value = value + self.uuid = uuid + self._jbroadcast = java_broadcast + self._pickle_registry = pickle_registry + + def __reduce__(self): + self._pickle_registry.add(self) + return (_from_uuid, (self.uuid, )) + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index ac7e4057e9..6f87206665 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -2,6 +2,7 @@ import os import atexit from tempfile import NamedTemporaryFile +from pyspark.broadcast import Broadcast from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, dumps from pyspark.rdd import RDD @@ -24,6 +25,11 @@ class SparkContext(object): self.defaultParallelism = \ defaultParallelism or self._jsc.sc().defaultParallelism() self.pythonExec = pythonExec + # Broadcast's __reduce__ method stores Broadcast instances here. + # This allows other code to determine which Broadcast instances have + # been pickled, so it can determine which Java broadcast objects to + # send. + self._pickled_broadcast_vars = set() def __del__(self): if self._jsc: @@ -52,7 +58,12 @@ class SparkContext(object): jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices) return RDD(jrdd, self) - def textFile(self, name, numSlices=None): - numSlices = numSlices or self.defaultParallelism - jrdd = self._jsc.textFile(name, numSlices) + def textFile(self, name, minSplits=None): + minSplits = minSplits or min(self.defaultParallelism, 2) + jrdd = self._jsc.textFile(name, minSplits) return RDD(jrdd, self) + + def broadcast(self, value): + jbroadcast = self._jsc.broadcast(bytearray(PickleSerializer.dumps(value))) + return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast, + self._pickled_broadcast_vars) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index af7703fdfc..4459095391 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -6,6 +6,8 @@ from pyspark.serializers import PickleSerializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup +from py4j.java_collections import ListConverter + class RDD(object): @@ -15,11 +17,15 @@ class RDD(object): self.ctx = ctx @classmethod - def _get_pipe_command(cls, command, functions): + def _get_pipe_command(cls, ctx, command, functions): worker_args = [command] for f in functions: worker_args.append(b64enc(cloudpickle.dumps(f))) - return " ".join(worker_args) + broadcast_vars = [x._jbroadcast for x in ctx._pickled_broadcast_vars] + broadcast_vars = ListConverter().convert(broadcast_vars, + ctx.gateway._gateway_client) + ctx._pickled_broadcast_vars.clear() + return (" ".join(worker_args), broadcast_vars) def cache(self): self.is_cached = True @@ -52,9 +58,10 @@ class RDD(object): def _pipe(self, functions, command): class_manifest = self._jrdd.classManifest() - pipe_command = RDD._get_pipe_command(command, functions) + (pipe_command, broadcast_vars) = \ + RDD._get_pipe_command(self.ctx, command, functions) python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, - False, self.ctx.pythonExec, class_manifest) + False, self.ctx.pythonExec, broadcast_vars, class_manifest) return python_rdd.asJavaRDD() def distinct(self): @@ -249,10 +256,12 @@ class RDD(object): def shuffle(self, numSplits, hashFunc=hash): if numSplits is None: numSplits = self.ctx.defaultParallelism - pipe_command = RDD._get_pipe_command('shuffle_map_step', [hashFunc]) + (pipe_command, broadcast_vars) = \ + RDD._get_pipe_command(self.ctx, 'shuffle_map_step', [hashFunc]) class_manifest = self._jrdd.classManifest() python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), - pipe_command, False, self.ctx.pythonExec, class_manifest) + pipe_command, False, self.ctx.pythonExec, broadcast_vars, + class_manifest) partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) @@ -360,12 +369,12 @@ class PipelinedRDD(RDD): @property def _jrdd(self): if not self._jrdd_val: - funcs = [self.func] - pipe_command = RDD._get_pipe_command("pipeline", funcs) + (pipe_command, broadcast_vars) = \ + RDD._get_pipe_command(self.ctx, "pipeline", [self.func]) class_manifest = self._prev_jrdd.classManifest() python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, self.preservesPartitioning, self.ctx.pythonExec, - class_manifest) + broadcast_vars, class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 76b09918e7..7402897ac8 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -5,6 +5,7 @@ import sys from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. +from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.serializers import dumps, loads, PickleSerializer import cPickle @@ -63,6 +64,11 @@ def do_shuffle_map_step(): def main(): + num_broadcast_variables = int(sys.stdin.readline().strip()) + for _ in range(num_broadcast_variables): + uuid = sys.stdin.read(36) + value = loads(sys.stdin) + _broadcastRegistry[uuid] = Broadcast(uuid, cPickle.loads(value)) command = sys.stdin.readline().strip() if command == "pipeline": do_pipeline() -- cgit v1.2.3 From 200d248dcc5903295296bf897211cf543b37f8c1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 25 Aug 2012 16:46:07 -0700 Subject: Simplify Python worker; pipeline the map step of partitionBy(). --- .../main/scala/spark/api/python/PythonRDD.scala | 34 +++-------- pyspark/pyspark/context.py | 9 ++- pyspark/pyspark/rdd.py | 70 +++++++++------------- pyspark/pyspark/serializers.py | 23 ++----- pyspark/pyspark/worker.py | 50 +++++----------- 5 files changed, 59 insertions(+), 127 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 5163812df4..b9091fd436 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -151,38 +151,18 @@ class PythonRDD[T: ClassManifest]( val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } -class PythonPairRDD[T: ClassManifest] ( - parent: RDD[T], command: Seq[String], envVars: Map[String, String], - preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) - extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase { - - def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, - pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = - this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars) - - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars) - - override def splits = parent.splits - - override val dependencies = List(new OneToOneDependency(parent)) - - override val partitioner = if (preservePartitoning) parent.partitioner else None - - override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = { - compute(split, envVars, command, parent, pythonExec, broadcastVars).grouped(2).map { +private class PairwiseRDD(prev: RDD[Array[Byte]]) extends + RDD[(Array[Byte], Array[Byte])](prev.context) { + override def splits = prev.splits + override val dependencies = List(new OneToOneDependency(prev)) + override def compute(split: Split) = + prev.iterator(split).grouped(2).map { case Seq(a, b) => (a, b) - case x => throw new Exception("PythonPairRDD: unexpected value: " + x) + case x => throw new Exception("PairwiseRDD: unexpected value: " + x) } - } - val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } - object PythonRDD { /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 6f87206665..b8490019e3 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -4,7 +4,7 @@ from tempfile import NamedTemporaryFile from pyspark.broadcast import Broadcast from pyspark.java_gateway import launch_gateway -from pyspark.serializers import PickleSerializer, dumps +from pyspark.serializers import dump_pickle, write_with_length from pyspark.rdd import RDD @@ -16,9 +16,8 @@ class SparkContext(object): asPickle = jvm.spark.api.python.PythonRDD.asPickle arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle - def __init__(self, master, name, defaultParallelism=None, - pythonExec='python'): + pythonExec='python'): self.master = master self.name = name self._jsc = self.jvm.JavaSparkContext(master, name) @@ -52,7 +51,7 @@ class SparkContext(object): # objects are written to a file and loaded through textFile(). tempFile = NamedTemporaryFile(delete=False) for x in c: - dumps(PickleSerializer.dumps(x), tempFile) + write_with_length(dump_pickle(x), tempFile) tempFile.close() atexit.register(lambda: os.unlink(tempFile.name)) jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices) @@ -64,6 +63,6 @@ class SparkContext(object): return RDD(jrdd, self) def broadcast(self, value): - jbroadcast = self._jsc.broadcast(bytearray(PickleSerializer.dumps(value))) + jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast, self._pickled_broadcast_vars) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 3528b8f308..21e822ba9f 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -3,7 +3,7 @@ from collections import Counter from itertools import chain, ifilter, imap from pyspark import cloudpickle -from pyspark.serializers import PickleSerializer +from pyspark.serializers import dump_pickle, load_pickle from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup @@ -17,17 +17,6 @@ class RDD(object): self.is_cached = False self.ctx = ctx - @classmethod - def _get_pipe_command(cls, ctx, command, functions): - worker_args = [command] - for f in functions: - worker_args.append(b64enc(cloudpickle.dumps(f))) - broadcast_vars = [x._jbroadcast for x in ctx._pickled_broadcast_vars] - broadcast_vars = ListConverter().convert(broadcast_vars, - ctx.gateway._gateway_client) - ctx._pickled_broadcast_vars.clear() - return (" ".join(worker_args), broadcast_vars) - def cache(self): self.is_cached = True self._jrdd.cache() @@ -66,14 +55,6 @@ class RDD(object): def func(iterator): return ifilter(f, iterator) return self.mapPartitions(func) - def _pipe(self, functions, command): - class_manifest = self._jrdd.classManifest() - (pipe_command, broadcast_vars) = \ - RDD._get_pipe_command(self.ctx, command, functions) - python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, - False, self.ctx.pythonExec, broadcast_vars, class_manifest) - return python_rdd.asJavaRDD() - def distinct(self): """ >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) @@ -89,7 +70,7 @@ class RDD(object): def takeSample(self, withReplacement, num, seed): vals = self._jrdd.takeSample(withReplacement, num, seed) - return [PickleSerializer.loads(bytes(x)) for x in vals] + return [load_pickle(bytes(x)) for x in vals] def union(self, other): """ @@ -148,7 +129,7 @@ class RDD(object): def collect(self): pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect()) - return PickleSerializer.loads(bytes(pickle)) + return load_pickle(bytes(pickle)) def reduce(self, f): """ @@ -216,19 +197,17 @@ class RDD(object): [2, 3] """ pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num)) - return PickleSerializer.loads(bytes(pickle)) + return load_pickle(bytes(pickle)) def first(self): """ >>> sc.parallelize([2, 3, 4]).first() 2 """ - return PickleSerializer.loads(bytes(self.ctx.asPickle(self._jrdd.first()))) + return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first()))) # TODO: saveAsTextFile - # TODO: saveAsObjectFile - # Pair functions def collectAsMap(self): @@ -303,19 +282,18 @@ class RDD(object): """ return python_right_outer_join(self, other, numSplits) - # TODO: pipelining - # TODO: optimizations def partitionBy(self, numSplits, hashFunc=hash): if numSplits is None: numSplits = self.ctx.defaultParallelism - (pipe_command, broadcast_vars) = \ - RDD._get_pipe_command(self.ctx, 'shuffle_map_step', [hashFunc]) - class_manifest = self._jrdd.classManifest() - python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), - pipe_command, False, self.ctx.pythonExec, broadcast_vars, - class_manifest) + def add_shuffle_key(iterator): + for (k, v) in iterator: + yield str(hashFunc(k)) + yield dump_pickle((k, v)) + keyed = PipelinedRDD(self, add_shuffle_key) + keyed._bypass_serializer = True + pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) - jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner) + jrdd = pairRDD.partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) return RDD(jrdd, self.ctx) @@ -430,17 +408,23 @@ class PipelinedRDD(RDD): self.ctx = prev.ctx self.prev = prev self._jrdd_val = None + self._bypass_serializer = False @property def _jrdd(self): - if not self._jrdd_val: - (pipe_command, broadcast_vars) = \ - RDD._get_pipe_command(self.ctx, "pipeline", [self.func]) - class_manifest = self._prev_jrdd.classManifest() - python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), - pipe_command, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, class_manifest) - self._jrdd_val = python_rdd.asJavaRDD() + if self._jrdd_val: + return self._jrdd_val + funcs = [self.func, self._bypass_serializer] + pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in funcs) + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], + self.ctx.gateway._gateway_client) + self.ctx._pickled_broadcast_vars.clear() + class_manifest = self._prev_jrdd.classManifest() + python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), + pipe_command, self.preservesPartitioning, self.ctx.pythonExec, + broadcast_vars, class_manifest) + self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index 7b3e6966e1..faa1e683c7 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -1,31 +1,20 @@ -""" -Data serialization methods. - -The Spark Python API is built on top of the Spark Java API. RDDs created in -Python are stored in Java as RDD[Array[Byte]]. Python objects are -automatically serialized/deserialized, so this representation is transparent to -the end-user. -""" -from collections import namedtuple -import cPickle import struct +import cPickle -Serializer = namedtuple("Serializer", ["dumps","loads"]) +def dump_pickle(obj): + return cPickle.dumps(obj, 2) -PickleSerializer = Serializer( - lambda obj: cPickle.dumps(obj, -1), - cPickle.loads) +load_pickle = cPickle.loads -def dumps(obj, stream): - # TODO: determining the length of non-byte objects. +def write_with_length(obj, stream): stream.write(struct.pack("!i", len(obj))) stream.write(obj) -def loads(stream): +def read_with_length(stream): length = stream.read(4) if length == "": raise EOFError diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 0f90c6ff46..a9ed71892f 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -7,61 +7,41 @@ from base64 import standard_b64decode # copy_reg module. from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import dumps, loads, PickleSerializer -import cPickle +from pyspark.serializers import write_with_length, read_with_length, \ + dump_pickle, load_pickle + # Redirect stdout to stderr so that users must return values from functions. old_stdout = sys.stdout sys.stdout = sys.stderr -def load_function(): - return cPickle.loads(standard_b64decode(sys.stdin.readline().strip())) - - -def output(x): - dumps(x, old_stdout) +def load_obj(): + return load_pickle(standard_b64decode(sys.stdin.readline().strip())) def read_input(): try: while True: - yield cPickle.loads(loads(sys.stdin)) + yield load_pickle(read_with_length(sys.stdin)) except EOFError: return -def do_pipeline(): - f = load_function() - for obj in f(read_input()): - output(PickleSerializer.dumps(obj)) - - -def do_shuffle_map_step(): - hashFunc = load_function() - while True: - try: - pickled = loads(sys.stdin) - except EOFError: - return - key = cPickle.loads(pickled)[0] - output(str(hashFunc(key))) - output(pickled) - - def main(): num_broadcast_variables = int(sys.stdin.readline().strip()) for _ in range(num_broadcast_variables): uuid = sys.stdin.read(36) - value = loads(sys.stdin) - _broadcastRegistry[uuid] = Broadcast(uuid, cPickle.loads(value)) - command = sys.stdin.readline().strip() - if command == "pipeline": - do_pipeline() - elif command == "shuffle_map_step": - do_shuffle_map_step() + value = read_with_length(sys.stdin) + _broadcastRegistry[uuid] = Broadcast(uuid, load_pickle(value)) + func = load_obj() + bypassSerializer = load_obj() + if bypassSerializer: + dumps = lambda x: x else: - raise Exception("Unsupported command %s" % command) + dumps = dump_pickle + for obj in func(read_input()): + write_with_length(dumps(obj), old_stdout) if __name__ == '__main__': -- cgit v1.2.3 From bff6a46359131a8f9bc38b93149b22baa7c711cd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 25 Aug 2012 18:00:25 -0700 Subject: Add pipe(), saveAsTextFile(), sc.union() to Python API. --- .../main/scala/spark/api/python/PythonRDD.scala | 8 +++++-- pyspark/pyspark/context.py | 14 ++++++------ pyspark/pyspark/rdd.py | 25 ++++++++++++++++++++-- 3 files changed, 37 insertions(+), 10 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index b9091fd436..4d3bdb3963 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -9,6 +9,7 @@ import spark._ import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import broadcast.Broadcast import scala.collection +import java.nio.charset.Charset trait PythonRDDBase { def compute[T](split: Split, envVars: Map[String, String], @@ -238,9 +239,12 @@ private object Pickle { val MARK : Byte = '(' val APPENDS : Byte = 'e' } -class ExtractValue extends spark.api.java.function.Function[(Array[Byte], - Array[Byte]), Array[Byte]] { +private class ExtractValue extends spark.api.java.function.Function[(Array[Byte], + Array[Byte]), Array[Byte]] { override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2 +} +private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { + override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index b8490019e3..04932c93f2 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -7,6 +7,8 @@ from pyspark.java_gateway import launch_gateway from pyspark.serializers import dump_pickle, write_with_length from pyspark.rdd import RDD +from py4j.java_collections import ListConverter + class SparkContext(object): @@ -39,12 +41,6 @@ class SparkContext(object): self._jsc = None def parallelize(self, c, numSlices=None): - """ - >>> sc = SparkContext("local", "test") - >>> rdd = sc.parallelize([(1, 2), (3, 4)]) - >>> rdd.collect() - [(1, 2), (3, 4)] - """ numSlices = numSlices or self.defaultParallelism # Calling the Java parallelize() method with an ArrayList is too slow, # because it sends O(n) Py4J commands. As an alternative, serialized @@ -62,6 +58,12 @@ class SparkContext(object): jrdd = self._jsc.textFile(name, minSplits) return RDD(jrdd, self) + def union(self, rdds): + first = rdds[0]._jrdd + rest = [x._jrdd for x in rdds[1:]] + rest = ListConverter().convert(rest, self.gateway._gateway_client) + return RDD(self._jsc.union(first, rest), self) + def broadcast(self, value): jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast, diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 21e822ba9f..8477f6dd02 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,6 +1,9 @@ from base64 import standard_b64encode as b64enc from collections import Counter from itertools import chain, ifilter, imap +import shlex +from subprocess import Popen, PIPE +from threading import Thread from pyspark import cloudpickle from pyspark.serializers import dump_pickle, load_pickle @@ -118,7 +121,20 @@ class RDD(object): """ return self.map(lambda x: (f(x), x)).groupByKey(numSplits) - # TODO: pipe + def pipe(self, command, env={}): + """ + >>> sc.parallelize([1, 2, 3]).pipe('cat').collect() + ['1', '2', '3'] + """ + def func(iterator): + pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) + def pipe_objs(out): + for obj in iterator: + out.write(str(obj).rstrip('\n') + '\n') + out.close() + Thread(target=pipe_objs, args=[pipe.stdin]).start() + return (x.rstrip('\n') for x in pipe.stdout) + return self.mapPartitions(func) def foreach(self, f): """ @@ -206,7 +222,12 @@ class RDD(object): """ return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first()))) - # TODO: saveAsTextFile + def saveAsTextFile(self, path): + def func(iterator): + return (str(x).encode("utf-8") for x in iterator) + keyed = PipelinedRDD(self, func) + keyed._bypass_serializer = True + keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path) # Pair functions -- cgit v1.2.3 From b4a2214218eeb9ebd95b59d88c2212fe717efd9e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 27 Aug 2012 22:49:29 -0700 Subject: More fault tolerance fixes to catch lost tasks --- core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala | 2 ++ .../main/scala/spark/scheduler/cluster/TaskSetManager.scala | 13 +++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala index 0fc1d8ed30..65e59841a9 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala @@ -20,6 +20,8 @@ class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host: def successful: Boolean = finished && !failed + def running: Boolean = !finished + def duration: Long = { if (!finished) { throw new UnsupportedOperationException("duration() called on unfinished tasks") diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 5412e8d8c0..17317e80df 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -265,6 +265,11 @@ class TaskSetManager( def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } val index = info.index info.markFailed() if (!finished(index)) { @@ -341,7 +346,7 @@ class TaskSetManager( } def hostLost(hostname: String) { - logInfo("Re-queueing tasks for " + hostname) + logInfo("Re-queueing tasks for " + hostname + " from TaskSet " + taskSet.id) // If some task has preferred locations only on hostname, put it in the no-prefs list // to avoid the wait from delay scheduling for (index <- getPendingTasksForHost(hostname)) { @@ -350,7 +355,7 @@ class TaskSetManager( pendingTasksWithNoPrefs += index } } - // Also re-enqueue any tasks that ran on the failed host if this is a shuffle map stage + // Re-enqueue any tasks that ran on the failed host if this is a shuffle map stage if (tasks(0).isInstanceOf[ShuffleMapTask]) { for ((tid, info) <- taskInfos if info.host == hostname) { val index = taskInfos(tid).index @@ -365,6 +370,10 @@ class TaskSetManager( } } } + // Also re-enqueue any tasks that were running on the node + for ((tid, info) <- taskInfos if info.running && info.host == hostname) { + taskLost(tid, TaskState.KILLED, null) + } } /** -- cgit v1.2.3 From 17af2df0cdcdc4f02013bd7b4351e0a9d9ee9b25 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 27 Aug 2012 23:07:32 -0700 Subject: Log levels --- core/src/main/scala/spark/MapOutputTracker.scala | 2 +- core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index e249430905..de23eb6f48 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -158,7 +158,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg def incrementGeneration() { generationLock.synchronized { generation += 1 - logInfo("Increasing generation to " + generation) + logDebug("Increasing generation to " + generation) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 17317e80df..5a7df6040c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -88,7 +88,7 @@ class TaskSetManager( // Figure out the current map output tracker generation and set it on all tasks val generation = sched.mapOutputTracker.getGeneration - logInfo("Generation for " + taskSet.id + ": " + generation) + logDebug("Generation for " + taskSet.id + ": " + generation) for (t <- tasks) { t.generation = generation } -- cgit v1.2.3 From 4db3a967669a53de4c4b79b4c0b70daa5accb682 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 29 Aug 2012 13:04:01 -0700 Subject: Made minor changes to reduce compilation errors in Eclipse. Twirl stuff still does not compile in Eclipse. --- .../src/main/scala/spark/network/ConnectionManager.scala | 16 +++++++++++++--- .../main/scala/spark/network/ConnectionManagerTest.scala | 5 ++++- 2 files changed, 17 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 0e764fff81..2bb5f5fc6b 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -16,6 +16,7 @@ import scala.collection.mutable.ArrayBuffer import akka.dispatch.{Await, Promise, ExecutionContext, Future} import akka.util.Duration +import akka.util.duration._ case class ConnectionManagerId(host: String, port: Int) { def toSocketAddress() = new InetSocketAddress(host, port) @@ -403,7 +404,10 @@ object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => {if (!f().isDefined) println("Failed")}) + }).foreach(f => { + val g = Await.result(f, 1 second) + if (!g.isDefined) println("Failed") + }) val finishTime = System.currentTimeMillis val mb = size * count / 1024.0 / 1024.0 @@ -430,7 +434,10 @@ object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => {if (!f().isDefined) println("Failed")}) + }).foreach(f => { + val g = Await.result(f, 1 second) + if (!g.isDefined) println("Failed") + }) val finishTime = System.currentTimeMillis val ms = finishTime - startTime @@ -457,7 +464,10 @@ object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => {if (!f().isDefined) println("Failed")}) + }).foreach(f => { + val g = Await.result(f, 1 second) + if (!g.isDefined) println("Failed") + }) val finishTime = System.currentTimeMillis Thread.sleep(1000) val mb = size * count / 1024.0 / 1024.0 diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala index 5d21bb793f..555b3454ee 100644 --- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala @@ -8,6 +8,9 @@ import scala.io.Source import java.nio.ByteBuffer import java.net.InetAddress +import akka.dispatch.Await +import akka.util.duration._ + object ConnectionManagerTest extends Logging{ def main(args: Array[String]) { if (args.length < 2) { @@ -53,7 +56,7 @@ object ConnectionManagerTest extends Logging{ logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) }) - val results = futures.map(f => f()) + val results = futures.map(f => Await.result(f, 1.second)) val finishTime = System.currentTimeMillis Thread.sleep(5000) -- cgit v1.2.3 From c4366eb76425d1c6aeaa7df750a2681a0da75db8 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 31 Aug 2012 00:34:24 +0000 Subject: Fixes to ShuffleFetcher --- .../scala/spark/BlockStoreShuffleFetcher.scala | 41 +++++++++------------- 1 file changed, 17 insertions(+), 24 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index 45a14c8290..0bbdb4e432 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -32,36 +32,29 @@ class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { (address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId))) } - try { - for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) { - blockOption match { - case Some(block) => { - val values = block - for(value <- values) { - val v = value.asInstanceOf[(K, V)] - func(v._1, v._2) - } - } - case None => { - throw new BlockException(blockId, "Did not get block " + blockId) + for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) { + blockOption match { + case Some(block) => { + val values = block + for(value <- values) { + val v = value.asInstanceOf[(K, V)] + func(v._1, v._2) } } - } - } catch { - // TODO: this is really ugly -- let's find a better way of throwing a FetchFailedException - case be: BlockException => { - val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]]*)".r - be.blockId match { - case regex(sId, mId, rId) => { - val address = addresses(mId.toInt) - throw new FetchFailedException(address, sId.toInt, mId.toInt, rId.toInt, be) - } - case _ => { - throw be + case None => { + val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]*)".r + blockId match { + case regex(shufId, mapId, reduceId) => + val addr = addresses(mapId.toInt) + throw new FetchFailedException(addr, shufId.toInt, mapId.toInt, reduceId.toInt, null) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block") } } } } + logDebug("Fetching and merging outputs of shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) } -- cgit v1.2.3 From 1b3e3352ebfed40881d534cd3096d4b6428c24d4 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 30 Aug 2012 17:59:25 -0700 Subject: Deserialize multi-get results in the caller's thread. This fixes an issue with shared buffers with the KryoSerializer. --- core/src/main/scala/spark/storage/BlockManager.scala | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 45f99717bc..e9197f7169 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -272,11 +272,15 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val totalBlocks = blocksByAddress.map(_._2.size).sum logDebug("Getting " + totalBlocks + " blocks") var startTime = System.currentTimeMillis - val results = new LinkedBlockingQueue[(String, Option[Iterator[Any]])] val localBlockIds = new ArrayBuffer[String]() val remoteBlockIds = new ArrayBuffer[String]() val remoteBlockIdsPerLocation = new HashMap[BlockManagerId, Seq[String]]() + // A queue to hold our results. Because we want all the deserializing the happen in the + // caller's thread, this will actually hold functions to produce the Iterator for each block. + // For local blocks we'll have an iterator already, while for remote ones we'll deserialize. + val results = new LinkedBlockingQueue[(String, Option[() => Iterator[Any]])] + // Split local and remote blocks for ((address, blockIds) <- blocksByAddress) { if (address == blockManagerId) { @@ -302,10 +306,8 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new SparkException( "Unexpected message " + blockMessage.getType + " received from " + cmId) } - val buffer = blockMessage.getData val blockId = blockMessage.getId - val block = dataDeserialize(buffer) - results.put((blockId, Some(block))) + results.put((blockId, Some(() => dataDeserialize(blockMessage.getData)))) logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) }) } @@ -323,9 +325,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // Get the local blocks while remote blocks are being fetched startTime = System.currentTimeMillis localBlockIds.foreach(id => { - get(id) match { + getLocal(id) match { case Some(block) => { - results.put((id, Some(block))) + results.put((id, Some(() => block))) logDebug("Got local block " + id) } case None => { @@ -343,7 +345,8 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 - results.take() + val (blockId, functionOption) = results.take() + (blockId, functionOption.map(_.apply())) } } } -- cgit v1.2.3 From 101ae493e26693146114ac01d50d411f5b2e0762 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 30 Aug 2012 22:24:14 -0700 Subject: Replicate serialized blocks properly, without sharing a ByteBuffer. --- core/src/main/scala/spark/storage/BlockManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index e9197f7169..8a013230da 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -456,7 +456,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // data is already serialized and ready for sending val replicationFuture = if (level.replication > 1) { Future { - replicate(blockId, bytes, level) + replicate(blockId, bytes.duplicate(), level) } } else { null -- cgit v1.2.3 From 113277549c5ee1bcd58c7cebc365d28d92b74b4a Mon Sep 17 00:00:00 2001 From: root Date: Fri, 31 Aug 2012 05:39:35 +0000 Subject: Really fixed the replication-3 issue. The problem was a few buffers not being rewound. --- core/src/main/scala/spark/storage/BlockManager.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 8a013230da..f2d9499bad 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -455,8 +455,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // Initiate the replication before storing it locally. This is faster as // data is already serialized and ready for sending val replicationFuture = if (level.replication > 1) { + val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper Future { - replicate(blockId, bytes.duplicate(), level) + replicate(blockId, bufferView, level) } } else { null @@ -514,15 +515,16 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m var peers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) for (peer: BlockManagerId <- peers) { val start = System.nanoTime + data.rewind() logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " - + data.array().length + " Bytes. To node: " + peer) + + data.limit() + " Bytes. To node: " + peer) if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel), new ConnectionManagerId(peer.ip, peer.port))) { logError("Failed to call syncPutBlock to " + peer) } logDebug("Replicated BlockId " + blockId + " once used " + (System.nanoTime - start) / 1e6 + " s; The size of the data is " + - data.array().length + " bytes.") + data.limit() + " bytes.") } } -- cgit v1.2.3 From c42e7ac2822f697a355650a70379d9e4ce2022c0 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 1 Sep 2012 04:31:11 +0000 Subject: More block manager fixes --- .../scala/spark/storage/BlockManagerWorker.scala | 2 +- core/src/main/scala/spark/storage/BlockStore.scala | 30 ++++++++++------------ 2 files changed, 15 insertions(+), 17 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala index d74cdb38a8..0658a57187 100644 --- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala @@ -73,7 +73,7 @@ class BlockManagerWorker(val blockManager: BlockManager) extends Logging { logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) blockManager.putBytes(id, bytes, level) logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " with data size: " + bytes.array().length) + + " with data size: " + bytes.limit) } private def getBlock(id: String): ByteBuffer = { diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala index 17f4f51aa8..77e0ed84c5 100644 --- a/core/src/main/scala/spark/storage/BlockStore.scala +++ b/core/src/main/scala/spark/storage/BlockStore.scala @@ -76,11 +76,11 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long) currentMemory += sizeEstimate logDebug("Block " + blockId + " stored as values to memory") } else { - val entry = new Entry(bytes, bytes.array().length, false) - ensureFreeSpace(bytes.array.length) + val entry = new Entry(bytes, bytes.limit, false) + ensureFreeSpace(bytes.limit) memoryStore.synchronized { memoryStore.put(blockId, entry) } - currentMemory += bytes.array().length - logDebug("Block " + blockId + " stored as " + bytes.array().length + " bytes to memory") + currentMemory += bytes.limit + logDebug("Block " + blockId + " stored as " + bytes.limit + " bytes to memory") } } @@ -97,11 +97,11 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long) return Left(elements.iterator) } else { val bytes = dataSerialize(values) - ensureFreeSpace(bytes.array().length) - val entry = new Entry(bytes, bytes.array().length, false) + ensureFreeSpace(bytes.limit) + val entry = new Entry(bytes, bytes.limit, false) memoryStore.synchronized { memoryStore.put(blockId, entry) } - currentMemory += bytes.array().length - logDebug("Block " + blockId + " stored as " + bytes.array.length + " bytes to memory") + currentMemory += bytes.limit + logDebug("Block " + blockId + " stored as " + bytes.limit + " bytes to memory") return Right(bytes) } } @@ -118,7 +118,7 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long) if (entry.deserialized) { return Some(entry.value.asInstanceOf[ArrayBuffer[Any]].toIterator) } else { - return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer])) + return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer].duplicate())) } } @@ -199,11 +199,11 @@ class DiskStore(blockManager: BlockManager, rootDirs: String) val file = createFile(blockId) if (file != null) { val channel = new RandomAccessFile(file, "rw").getChannel() - val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.array.length) - buffer.put(bytes.array) + val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.limit) + buffer.put(bytes) channel.close() val finishTime = System.currentTimeMillis - logDebug("Block " + blockId + " stored to file of " + bytes.array.length + " bytes to disk in " + (finishTime - startTime) + " ms") + logDebug("Block " + blockId + " stored to file of " + bytes.limit + " bytes to disk in " + (finishTime - startTime) + " ms") } else { logError("File not created for block " + blockId) } @@ -211,7 +211,7 @@ class DiskStore(blockManager: BlockManager, rootDirs: String) def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = { val bytes = dataSerialize(values) - logDebug("Converted block " + blockId + " to " + bytes.array.length + " bytes") + logDebug("Converted block " + blockId + " to " + bytes.limit + " bytes") putBytes(blockId, bytes, level) return Right(bytes) } @@ -220,9 +220,7 @@ class DiskStore(blockManager: BlockManager, rootDirs: String) val file = getFile(blockId) val length = file.length().toInt val channel = new RandomAccessFile(file, "r").getChannel() - val bytes = ByteBuffer.allocate(length) - bytes.put(channel.map(MapMode.READ_WRITE, 0, length)) - return Some(bytes) + Some(channel.map(MapMode.READ_WRITE, 0, length)) } def getValues(blockId: String): Option[Iterator[Any]] = { -- cgit v1.2.3 From 44758aa8e2337364610ee80fa9ec913301712078 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 1 Sep 2012 00:17:59 -0700 Subject: First work towards a RawInputDStream and a sender program for it. --- .../src/main/scala/spark/DaemonThreadFactory.scala | 12 +- .../scala/spark/util/RateLimitedOutputStream.scala | 56 +++++ .../spark/streaming/NetworkInputDStream.scala | 33 +-- .../spark/streaming/NetworkInputReceiver.scala | 248 --------------------- .../streaming/NetworkInputReceiverMessage.scala | 7 + .../spark/streaming/NetworkInputTracker.scala | 7 +- .../scala/spark/streaming/ObjectInputDStream.scala | 16 ++ .../spark/streaming/ObjectInputReceiver.scala | 244 ++++++++++++++++++++ .../scala/spark/streaming/RawInputDStream.scala | 114 ++++++++++ .../scala/spark/streaming/StreamingContext.scala | 8 +- .../scala/spark/streaming/util/RawTextSender.scala | 51 +++++ 11 files changed, 512 insertions(+), 284 deletions(-) create mode 100644 core/src/main/scala/spark/util/RateLimitedOutputStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala create mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala create mode 100644 streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala create mode 100644 streaming/src/main/scala/spark/streaming/RawInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/RawTextSender.scala (limited to 'core') diff --git a/core/src/main/scala/spark/DaemonThreadFactory.scala b/core/src/main/scala/spark/DaemonThreadFactory.scala index 003880c5e8..56e59adeb7 100644 --- a/core/src/main/scala/spark/DaemonThreadFactory.scala +++ b/core/src/main/scala/spark/DaemonThreadFactory.scala @@ -6,9 +6,13 @@ import java.util.concurrent.ThreadFactory * A ThreadFactory that creates daemon threads */ private object DaemonThreadFactory extends ThreadFactory { - override def newThread(r: Runnable): Thread = { - val t = new Thread(r) - t.setDaemon(true) - return t + override def newThread(r: Runnable): Thread = new DaemonThread(r) +} + +private class DaemonThread(r: Runnable = null) extends Thread { + override def run() { + if (r != null) { + r.run() + } } } \ No newline at end of file diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala new file mode 100644 index 0000000000..10f2272707 --- /dev/null +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -0,0 +1,56 @@ +package spark.util + +import java.io.OutputStream + +class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { + var lastSyncTime = System.nanoTime() + var bytesWrittenSinceSync: Long = 0 + + override def write(b: Int) { + waitToWrite(1) + out.write(b) + } + + override def write(bytes: Array[Byte]) { + write(bytes, 0, bytes.length) + } + + override def write(bytes: Array[Byte], offset: Int, length: Int) { + val CHUNK_SIZE = 8192 + var pos = 0 + while (pos < length) { + val writeSize = math.min(length - pos, CHUNK_SIZE) + waitToWrite(writeSize) + out.write(bytes, offset + pos, length - pos) + pos += writeSize + } + } + + def waitToWrite(numBytes: Int) { + while (true) { + val now = System.nanoTime() + val elapsed = math.max(now - lastSyncTime, 1) + val rate = bytesWrittenSinceSync.toDouble / (elapsed / 1.0e9) + if (rate < bytesPerSec) { + // It's okay to write; just update some variables and return + bytesWrittenSinceSync += numBytes + if (now > lastSyncTime + (1e10).toLong) { + // Ten seconds have passed since lastSyncTime; let's resync + lastSyncTime = now + bytesWrittenSinceSync = numBytes + } + return + } else { + Thread.sleep(5) + } + } + } + + override def flush() { + out.flush() + } + + override def close() { + out.close() + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala index ee09324c8c..bf83f98ec4 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -1,36 +1,23 @@ package spark.streaming -import akka.actor._ -import akka.pattern.ask -import akka.util.duration._ -import akka.dispatch._ - import spark.RDD import spark.BlockRDD -import spark.Logging - -import java.io.InputStream +abstract class NetworkInputDStream[T: ClassManifest](@transient ssc: StreamingContext) + extends InputDStream[T](ssc) { -class NetworkInputDStream[T: ClassManifest]( - @transient ssc: StreamingContext, - val host: String, - val port: Int, - val bytesToObjects: InputStream => Iterator[T] - ) extends InputDStream[T](ssc) with Logging { - val id = ssc.getNewNetworkStreamId() - def start() { } + def start() {} - def stop() { } + def stop() {} override def compute(validTime: Time): Option[RDD[T]] = { val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) - return Some(new BlockRDD[T](ssc.sc, blockIds)) + Some(new BlockRDD[T](ssc.sc, blockIds)) } - - def createReceiver(): NetworkInputReceiver[T] = { - new NetworkInputReceiver(id, host, port, bytesToObjects) - } -} \ No newline at end of file + + /** Called on workers to run a receiver for the stream. */ + def runReceiver(): Unit +} + diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala b/streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala deleted file mode 100644 index 7add6246b7..0000000000 --- a/streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala +++ /dev/null @@ -1,248 +0,0 @@ -package spark.streaming - -import spark.Logging -import spark.storage.BlockManager -import spark.storage.StorageLevel -import spark.SparkEnv -import spark.streaming.util.SystemClock -import spark.streaming.util.RecurringTimer - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.Queue -import scala.collection.mutable.SynchronizedPriorityQueue -import scala.math.Ordering - -import java.net.InetSocketAddress -import java.net.Socket -import java.io.InputStream -import java.io.BufferedInputStream -import java.io.DataInputStream -import java.io.EOFException -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.ArrayBlockingQueue - -import akka.actor._ -import akka.pattern.ask -import akka.util.duration._ -import akka.dispatch._ - -trait NetworkInputReceiverMessage -case class GetBlockIds(time: Long) extends NetworkInputReceiverMessage -case class GotBlockIds(streamId: Int, blocksIds: Array[String]) extends NetworkInputReceiverMessage -case class StopReceiver() extends NetworkInputReceiverMessage - -class NetworkInputReceiver[T: ClassManifest](streamId: Int, host: String, port: Int, bytesToObjects: InputStream => Iterator[T]) -extends Logging { - - class ReceiverActor extends Actor { - override def preStart() = { - logInfo("Attempting to register") - val ip = System.getProperty("spark.master.host", "localhost") - val port = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "NetworkInputTracker" - val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) - val trackerActor = env.actorSystem.actorFor(url) - val timeout = 100.milliseconds - val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) - Await.result(future, timeout) - } - - def receive = { - case GetBlockIds(time) => { - logInfo("Got request for block ids for " + time) - sender ! GotBlockIds(streamId, dataHandler.getPushedBlocks()) - } - - case StopReceiver() => { - if (receivingThread != null) { - receivingThread.interrupt() - } - sender ! true - } - } - } - - class DataHandler { - - class Block(val time: Long, val iterator: Iterator[T]) { - val blockId = "input-" + streamId + "-" + time - var pushed = true - override def toString() = "input block " + blockId - } - - val clock = new SystemClock() - val blockInterval = 200L - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) - val blockOrdering = new Ordering[Block] { - def compare(b1: Block, b2: Block) = (b1.time - b2.time).toInt - } - val blockStorageLevel = StorageLevel.DISK_AND_MEMORY - val blocksForPushing = new ArrayBlockingQueue[Block](1000) - val blocksForReporting = new SynchronizedPriorityQueue[Block]()(blockOrdering) - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - var currentBuffer = new ArrayBuffer[T] - - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Data handler started") - } - - def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - } - - def += (obj: T) { - currentBuffer += obj - } - - def updateCurrentBuffer(time: Long) { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[T] - if (newBlockBuffer.size > 0) { - val newBlock = new Block(time - blockInterval, newBlockBuffer.toIterator) - blocksForPushing.add(newBlock) - blocksForReporting.enqueue(newBlock) - } - } - - def keepPushingBlocks() { - logInfo("Block pushing thread started") - try { - while(true) { - val block = blocksForPushing.take() - if (blockManager != null) { - blockManager.put(block.blockId, block.iterator, blockStorageLevel) - block.pushed = true - } else { - logWarning(block + " not put as block manager is null") - } - } - } catch { - case ie: InterruptedException => println("Block pushing thread interrupted") - case e: Exception => e.printStackTrace() - } - } - - def getPushedBlocks(): Array[String] = { - val pushedBlocks = new ArrayBuffer[String]() - var loop = true - while(loop && !blocksForReporting.isEmpty) { - val block = blocksForReporting.dequeue() - if (block == null) { - loop = false - } else if (!block.pushed) { - blocksForReporting.enqueue(block) - } else { - pushedBlocks += block.blockId - } - } - logInfo("Got " + pushedBlocks.size + " blocks") - pushedBlocks.toArray - } - } - - val blockManager = if (SparkEnv.get != null) SparkEnv.get.blockManager else null - val dataHandler = new DataHandler() - val env = SparkEnv.get - - var receiverActor: ActorRef = null - var receivingThread: Thread = null - - def run() { - initLogging() - var socket: Socket = null - try { - if (SparkEnv.get != null) { - receiverActor = SparkEnv.get.actorSystem.actorOf(Props(new ReceiverActor), "ReceiverActor-" + streamId) - } - dataHandler.start() - socket = connect() - receivingThread = Thread.currentThread() - receive(socket) - } catch { - case ie: InterruptedException => logInfo("Receiver interrupted") - } finally { - receivingThread = null - if (socket != null) socket.close() - dataHandler.stop() - } - } - - def connect(): Socket = { - logInfo("Connecting to " + host + ":" + port) - val socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) - socket - } - - def receive(socket: Socket) { - val iterator = bytesToObjects(socket.getInputStream()) - while(iterator.hasNext) { - val obj = iterator.next - dataHandler += obj - } - } -} - - -object NetworkInputReceiver { - - def bytesToLines(inputStream: InputStream): Iterator[String] = { - val bufferedInputStream = new BufferedInputStream(inputStream) - val dataInputStream = new DataInputStream(bufferedInputStream) - - val iterator = new Iterator[String] { - var gotNext = false - var finished = false - var nextValue: String = null - - private def getNext() { - try { - nextValue = dataInputStream.readLine() - println("[" + nextValue + "]") - } catch { - case eof: EOFException => - finished = true - } - gotNext = true - } - - override def hasNext: Boolean = { - if (!gotNext) { - getNext() - } - if (finished) { - dataInputStream.close() - } - !finished - } - - - override def next(): String = { - if (!gotNext) { - getNext() - } - if (finished) { - throw new NoSuchElementException("End of stream") - } - gotNext = false - nextValue - } - } - iterator - } - - def main(args: Array[String]) { - if (args.length < 2) { - println("NetworkReceiver ") - System.exit(1) - } - val host = args(0) - val port = args(1).toInt - val receiver = new NetworkInputReceiver(0, host, port, bytesToLines) - receiver.run() - } -} diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala b/streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala new file mode 100644 index 0000000000..deaffe98c8 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala @@ -0,0 +1,7 @@ +package spark.streaming + +sealed trait NetworkInputReceiverMessage + +case class GetBlockIds(time: Long) extends NetworkInputReceiverMessage +case class GotBlockIds(streamId: Int, blocksIds: Array[String]) extends NetworkInputReceiverMessage +case object StopReceiver extends NetworkInputReceiverMessage diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index 07758665c9..acf97c1883 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -52,9 +52,7 @@ extends Logging { if (!iterator.hasNext) { throw new Exception("Could not start receiver as details not found.") } - val stream = iterator.next - val receiver = stream.createReceiver() - receiver.run() + iterator.next().runReceiver() } ssc.sc.runJob(tempRDD, startReceiver) @@ -62,8 +60,7 @@ extends Logging { def stopReceivers() { implicit val ec = env.actorSystem.dispatcher - val message = new StopReceiver() - val listOfFutures = receiverInfo.values.map(_.ask(message)(timeout)).toList + val listOfFutures = receiverInfo.values.map(_.ask(StopReceiver)(timeout)).toList val futureOfList = Future.sequence(listOfFutures) Await.result(futureOfList, timeout) } diff --git a/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala b/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala new file mode 100644 index 0000000000..2396b374a0 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala @@ -0,0 +1,16 @@ +package spark.streaming + +import java.io.InputStream + +class ObjectInputDStream[T: ClassManifest]( + @transient ssc: StreamingContext, + val host: String, + val port: Int, + val bytesToObjects: InputStream => Iterator[T]) + extends NetworkInputDStream[T](ssc) { + + override def runReceiver() { + new ObjectInputReceiver(id, host, port, bytesToObjects).run() + } +} + diff --git a/streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala b/streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala new file mode 100644 index 0000000000..70fa2cdf07 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala @@ -0,0 +1,244 @@ +package spark.streaming + +import spark.Logging +import spark.storage.BlockManager +import spark.storage.StorageLevel +import spark.SparkEnv +import spark.streaming.util.SystemClock +import spark.streaming.util.RecurringTimer + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.Queue +import scala.collection.mutable.SynchronizedPriorityQueue +import scala.math.Ordering + +import java.net.InetSocketAddress +import java.net.Socket +import java.io.InputStream +import java.io.BufferedInputStream +import java.io.DataInputStream +import java.io.EOFException +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.ArrayBlockingQueue + +import akka.actor._ +import akka.pattern.ask +import akka.util.duration._ +import akka.dispatch._ + +class ObjectInputReceiver[T: ClassManifest]( + streamId: Int, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T]) + extends Logging { + + class ReceiverActor extends Actor { + override def preStart() { + logInfo("Attempting to register") + val ip = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val actorName: String = "NetworkInputTracker" + val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) + val trackerActor = env.actorSystem.actorFor(url) + val timeout = 1.seconds + val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) + Await.result(future, timeout) + } + + def receive = { + case GetBlockIds(time) => { + logInfo("Got request for block ids for " + time) + sender ! GotBlockIds(streamId, dataHandler.getPushedBlocks()) + } + + case StopReceiver => { + if (receivingThread != null) { + receivingThread.interrupt() + } + sender ! true + } + } + } + + class DataHandler { + class Block(val time: Long, val iterator: Iterator[T]) { + val blockId = "input-" + streamId + "-" + time + var pushed = true + override def toString = "input block " + blockId + } + + val clock = new SystemClock() + val blockInterval = 200L + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockOrdering = new Ordering[Block] { + def compare(b1: Block, b2: Block) = (b1.time - b2.time).toInt + } + val blockStorageLevel = StorageLevel.DISK_AND_MEMORY + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blocksForReporting = new SynchronizedPriorityQueue[Block]()(blockOrdering) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + } + + def += (obj: T) { + currentBuffer += obj + } + + def updateCurrentBuffer(time: Long) { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val newBlock = new Block(time - blockInterval, newBlockBuffer.toIterator) + blocksForPushing.add(newBlock) + blocksForReporting.enqueue(newBlock) + } + } + + def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + if (blockManager != null) { + blockManager.put(block.blockId, block.iterator, blockStorageLevel) + block.pushed = true + } else { + logWarning(block + " not put as block manager is null") + } + } + } catch { + case ie: InterruptedException => println("Block pushing thread interrupted") + case e: Exception => e.printStackTrace() + } + } + + def getPushedBlocks(): Array[String] = { + val pushedBlocks = new ArrayBuffer[String]() + var loop = true + while(loop && !blocksForReporting.isEmpty) { + val block = blocksForReporting.dequeue() + if (block == null) { + loop = false + } else if (!block.pushed) { + blocksForReporting.enqueue(block) + } else { + pushedBlocks += block.blockId + } + } + logInfo("Got " + pushedBlocks.size + " blocks") + pushedBlocks.toArray + } + } + + val blockManager = if (SparkEnv.get != null) SparkEnv.get.blockManager else null + val dataHandler = new DataHandler() + val env = SparkEnv.get + + var receiverActor: ActorRef = null + var receivingThread: Thread = null + + def run() { + initLogging() + var socket: Socket = null + try { + if (SparkEnv.get != null) { + receiverActor = SparkEnv.get.actorSystem.actorOf(Props(new ReceiverActor), "ReceiverActor-" + streamId) + } + dataHandler.start() + socket = connect() + receivingThread = Thread.currentThread() + receive(socket) + } catch { + case ie: InterruptedException => logInfo("Receiver interrupted") + } finally { + receivingThread = null + if (socket != null) socket.close() + dataHandler.stop() + } + } + + def connect(): Socket = { + logInfo("Connecting to " + host + ":" + port) + val socket = new Socket(host, port) + logInfo("Connected to " + host + ":" + port) + socket + } + + def receive(socket: Socket) { + val iterator = bytesToObjects(socket.getInputStream()) + while(iterator.hasNext) { + val obj = iterator.next + dataHandler += obj + } + } +} + + +object ObjectInputReceiver { + def bytesToLines(inputStream: InputStream): Iterator[String] = { + val bufferedInputStream = new BufferedInputStream(inputStream) + val dataInputStream = new DataInputStream(bufferedInputStream) + + val iterator = new Iterator[String] { + var gotNext = false + var finished = false + var nextValue: String = null + + private def getNext() { + try { + nextValue = dataInputStream.readLine() + println("[" + nextValue + "]") + } catch { + case eof: EOFException => + finished = true + } + gotNext = true + } + + override def hasNext: Boolean = { + if (!gotNext) { + getNext() + } + if (finished) { + dataInputStream.close() + } + !finished + } + + override def next(): String = { + if (!gotNext) { + getNext() + } + if (finished) { + throw new NoSuchElementException("End of stream") + } + gotNext = false + nextValue + } + } + iterator + } + + def main(args: Array[String]) { + if (args.length < 2) { + println("ObjectInputReceiver ") + System.exit(1) + } + val host = args(0) + val port = args(1).toInt + val receiver = new ObjectInputReceiver(0, host, port, bytesToLines) + receiver.run() + } +} diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala new file mode 100644 index 0000000000..49e4781e75 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -0,0 +1,114 @@ +package spark.streaming + +import akka.actor._ +import akka.pattern.ask +import akka.util.duration._ +import akka.dispatch._ +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.{ReadableByteChannel, SocketChannel} +import java.io.EOFException +import java.util.concurrent.ArrayBlockingQueue +import scala.collection.mutable.ArrayBuffer +import spark.{DaemonThread, Logging, SparkEnv} +import spark.storage.StorageLevel + +/** + * An input stream that reads blocks of serialized objects from a given network address. + * The blocks will be inserted directly into the block store. This is the fastest way to get + * data into Spark Streaming, though it requires the sender to batch data and serialize it + * in the format that the system is configured with. + */ +class RawInputDStream[T: ClassManifest]( + @transient ssc: StreamingContext, + host: String, + port: Int) + extends NetworkInputDStream[T](ssc) with Logging { + + val streamId = id + + /** Called on workers to run a receiver for the stream. */ + def runReceiver() { + val env = SparkEnv.get + val actor = env.actorSystem.actorOf( + Props(new ReceiverActor(env, Thread.currentThread)), "ReceiverActor-" + streamId) + + // Open a socket to the target address and keep reading from it + logInfo("Connecting to " + host + ":" + port) + val channel = SocketChannel.open() + channel.configureBlocking(true) + channel.connect(new InetSocketAddress(host, port)) + logInfo("Connected to " + host + ":" + port) + + val queue = new ArrayBlockingQueue[ByteBuffer](2) + + new DaemonThread { + override def run() { + var nextBlockNumber = 0 + while (true) { + val buffer = queue.take() + val blockId = "input-" + streamId + "-" + nextBlockNumber + nextBlockNumber += 1 + env.blockManager.putBytes(blockId, buffer, StorageLevel.MEMORY_ONLY_2) + actor ! BlockPublished(blockId) + } + } + }.start() + + val lengthBuffer = ByteBuffer.allocate(4) + while (true) { + lengthBuffer.clear() + readFully(channel, lengthBuffer) + lengthBuffer.flip() + val length = lengthBuffer.getInt() + val dataBuffer = ByteBuffer.allocate(length) + readFully(channel, dataBuffer) + dataBuffer.flip() + logInfo("Read a block with " + length + " bytes") + queue.put(dataBuffer) + } + } + + /** Read a buffer fully from a given Channel */ + private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) { + while (dest.position < dest.limit) { + if (channel.read(dest) == -1) { + throw new EOFException("End of channel") + } + } + } + + /** Message sent to ReceiverActor to tell it that a block was published */ + case class BlockPublished(blockId: String) {} + + /** A helper actor that communicates with the NetworkInputTracker */ + private class ReceiverActor(env: SparkEnv, receivingThread: Thread) extends Actor { + val newBlocks = new ArrayBuffer[String] + + override def preStart() { + logInfo("Attempting to register with tracker") + val ip = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val actorName: String = "NetworkInputTracker" + val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) + val trackerActor = env.actorSystem.actorFor(url) + val timeout = 1.seconds + val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) + Await.result(future, timeout) + } + + override def receive = { + case BlockPublished(blockId) => + newBlocks += blockId + + case GetBlockIds(time) => + logInfo("Got request for block IDs for " + time) + sender ! GotBlockIds(streamId, newBlocks.toArray) + newBlocks.clear() + + case StopReceiver => + receivingThread.interrupt() + sender ! true + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 0ac86cbdf2..feb769e036 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -52,22 +52,22 @@ class StreamingContext ( private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() def createNetworkTextStream(hostname: String, port: Int): DStream[String] = { - createNetworkStream[String](hostname, port, NetworkInputReceiver.bytesToLines) + createNetworkObjectStream[String](hostname, port, ObjectInputReceiver.bytesToLines) } - def createNetworkStream[T: ClassManifest]( + def createNetworkObjectStream[T: ClassManifest]( hostname: String, port: Int, converter: (InputStream) => Iterator[T] ): DStream[T] = { - val inputStream = new NetworkInputDStream[T](this, hostname, port, converter) + val inputStream = new ObjectInputDStream[T](this, hostname, port, converter) inputStreams += inputStream inputStream } /* def createHttpTextStream(url: String): DStream[String] = { - createHttpStream(url, NetworkInputReceiver.bytesToLines) + createHttpStream(url, ObjectInputReceiver.bytesToLines) } def createHttpStream[T: ClassManifest]( diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala new file mode 100644 index 0000000000..60d5849d71 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala @@ -0,0 +1,51 @@ +package spark.streaming.util + +import spark.util.{RateLimitedOutputStream, IntParam} +import java.net.ServerSocket +import spark.{Logging, KryoSerializer} +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream +import io.Source +import java.io.IOException + +/** + * A helper program that sends blocks of Kryo-serialized text strings out on a socket at a + * specified rate. Used to feed data into RawInputDStream. + */ +object RawTextSender extends Logging { + def main(args: Array[String]) { + if (args.length != 4) { + System.err.println("Usage: RawTextSender ") + } + // Parse the arguments using a pattern match + val Array(IntParam(port), file, IntParam(blockSize), IntParam(bytesPerSec)) = args + + // Repeat the input data multiple times to fill in a buffer + val lines = Source.fromFile(file).getLines().toArray + val bufferStream = new FastByteArrayOutputStream(blockSize + 1000) + val ser = new KryoSerializer().newInstance() + val serStream = ser.serializeStream(bufferStream) + var i = 0 + while (bufferStream.position < blockSize) { + serStream.writeObject(lines(i)) + i = (i + 1) % lines.length + } + bufferStream.trim() + val array = bufferStream.array + + val serverSocket = new ServerSocket(port) + + while (true) { + val socket = serverSocket.accept() + val out = new RateLimitedOutputStream(socket.getOutputStream, bytesPerSec) + try { + while (true) { + out.write(array) + } + } catch { + case e: IOException => + logError("Socket closed: ", e) + socket.close() + } + } + } +} -- cgit v1.2.3 From f84d2bbe55aaf3ef7a6631b9018a573aa5729ff7 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 1 Sep 2012 00:31:15 -0700 Subject: Bug fixes to RateLimitedOutputStream --- core/src/main/scala/spark/util/RateLimitedOutputStream.scala | 2 +- streaming/src/main/scala/spark/streaming/util/RawTextSender.scala | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index 10f2272707..d11ed163ce 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -21,7 +21,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu while (pos < length) { val writeSize = math.min(length - pos, CHUNK_SIZE) waitToWrite(writeSize) - out.write(bytes, offset + pos, length - pos) + out.write(bytes, offset + pos, writeSize) pos += writeSize } } diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala index 60d5849d71..85927c02ec 100644 --- a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala @@ -15,6 +15,7 @@ object RawTextSender extends Logging { def main(args: Array[String]) { if (args.length != 4) { System.err.println("Usage: RawTextSender ") + System.exit(1) } // Parse the arguments using a pattern match val Array(IntParam(port), file, IntParam(blockSize), IntParam(bytesPerSec)) = args @@ -36,6 +37,7 @@ object RawTextSender extends Logging { while (true) { val socket = serverSocket.accept() + logInfo("Got a new connection") val out = new RateLimitedOutputStream(socket.getOutputStream, bytesPerSec) try { while (true) { @@ -43,7 +45,7 @@ object RawTextSender extends Logging { } } catch { case e: IOException => - logError("Socket closed: ", e) + logError("Socket closed", e) socket.close() } } -- cgit v1.2.3 From 6025889be0ecf1c9849c5c940a7171c6d82be0b5 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 1 Sep 2012 20:51:07 +0000 Subject: More raw network receiver programs --- .../mesos/CoarseMesosSchedulerBackend.scala | 4 ++- .../scala/spark/streaming/examples/GrepRaw.scala | 33 +++++++++++++++++ .../spark/streaming/examples/WordCount2.scala | 2 +- .../spark/streaming/examples/WordCountRaw.scala | 42 ++++++++++++++++++++++ 4 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index 31784985dc..fdf007ffb2 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -80,6 +80,8 @@ class CoarseMesosSchedulerBackend( "property, the SPARK_HOME environment variable or the SparkContext constructor") } + val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt + var nextMesosTaskId = 0 def newMesosTaskId(): Int = { @@ -177,7 +179,7 @@ class CoarseMesosSchedulerBackend( val task = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse)) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) .addResources(createResource("mem", executorMemory)) diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala new file mode 100644 index 0000000000..cc52da7bd4 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala @@ -0,0 +1,33 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel +import spark.streaming._ +import spark.streaming.StreamingContext._ + +object GrepRaw { + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println("Usage: GrepRaw ") + System.exit(1) + } + + val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "GrepRaw") + ssc.setBatchDuration(Milliseconds(batchMillis)) + + // Make sure some tasks have started on each node + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + + val rawStreams = (1 to numStreams).map(_ => + ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + val union = new UnifiedDStream(rawStreams) + union.filter(_.contains("Culpepper")).count().foreachRDD(r => + println("Grep count: " + r.collect().mkString)) + ssc.start() + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index ce553758a7..8c2724e97c 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -100,7 +100,7 @@ object WordCount2 { .reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt) windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Milliseconds(chkptMillis.toLong)) - windowedCounts.print() + windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala new file mode 100644 index 0000000000..298d9ef381 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -0,0 +1,42 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel +import spark.streaming._ +import spark.streaming.StreamingContext._ + +object WordCountRaw { + def main(args: Array[String]) { + if (args.length != 7) { + System.err.println("Usage: WordCountRaw ") + System.exit(1) + } + + val Array(master, IntParam(streams), host, IntParam(port), IntParam(batchMs), + IntParam(chkptMs), IntParam(reduces)) = args + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "WordCountRaw") + ssc.setBatchDuration(Milliseconds(batchMs)) + + // Make sure some tasks have started on each node + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + + val rawStreams = (1 to streams).map(_ => + ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + val union = new UnifiedDStream(rawStreams) + + import WordCount2_ExtraFunctions._ + + val windowedCounts = union.mapPartitions(splitAndCountPartitions) + .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, + Milliseconds(chkptMs)) + //windowedCounts.print() // TODO: something else? + windowedCounts.foreachRDD(r => println("Element count: " + r.count())) + + ssc.start() + } +} -- cgit v1.2.3 From ceabf71257631c9e46f82897e540369b99a6bb57 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 1 Sep 2012 21:52:42 +0000 Subject: tweaks --- core/src/main/scala/spark/storage/StorageLevel.scala | 1 + streaming/src/main/scala/spark/streaming/util/RawTextSender.scala | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index f067a2a6c5..a64393eba7 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -66,6 +66,7 @@ class StorageLevel( object StorageLevel { val NONE = new StorageLevel(false, false, false) val DISK_ONLY = new StorageLevel(true, false, false) + val DISK_ONLY_2 = new StorageLevel(true, false, false, 2) val MEMORY_ONLY = new StorageLevel(false, true, false) val MEMORY_ONLY_2 = new StorageLevel(false, true, false, 2) val MEMORY_ONLY_DESER = new StorageLevel(false, true, true) diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala index 8db651ba19..d8b987ec86 100644 --- a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala @@ -52,7 +52,7 @@ object RawTextSender extends Logging { } } catch { case e: IOException => - logError("Socket closed", e) + logError("Client disconnected") socket.close() } } -- cgit v1.2.3 From 7c09ad0e04639040864236cf13a9fedff6736b5d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 4 Sep 2012 19:11:49 -0700 Subject: Changed DStream member access permissions from private to protected. Updated StateDStream to checkpoint RDDs and forget lineage. --- core/src/main/scala/spark/RDD.scala | 2 +- .../src/main/scala/spark/streaming/DStream.scala | 16 ++-- .../scala/spark/streaming/QueueInputDStream.scala | 2 +- .../main/scala/spark/streaming/StateDStream.scala | 93 ++++++++++++++++------ .../test/scala/spark/streaming/DStreamSuite.scala | 4 +- 5 files changed, 81 insertions(+), 36 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 3fe8e8a4bf..d28f3593fe 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -94,7 +94,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def getStorageLevel = storageLevel - def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER): RDD[T] = { + def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER_2): RDD[T] = { if (!level.useDisk && level.replication < 2) { throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 9b0115eef6..20f1c4db20 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -41,17 +41,17 @@ extends Logging with Serializable { */ // Variable to store the RDDs generated earlier in time - @transient private val generatedRDDs = new HashMap[Time, RDD[T]] () + @transient protected val generatedRDDs = new HashMap[Time, RDD[T]] () // Variable to be set to the first time seen by the DStream (effective time zero) - private[streaming] var zeroTime: Time = null + protected[streaming] var zeroTime: Time = null // Variable to specify storage level - private var storageLevel: StorageLevel = StorageLevel.NONE + protected var storageLevel: StorageLevel = StorageLevel.NONE // Checkpoint level and checkpoint interval - private var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint - private var checkpointInterval: Time = null + protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint + protected var checkpointInterval: Time = null // Change this RDD's storage level def persist( @@ -84,7 +84,7 @@ extends Logging with Serializable { * the validity of future times is calculated. This method also recursively initializes * its parent DStreams. */ - def initialize(time: Time) { + protected[streaming] def initialize(time: Time) { if (zeroTime == null) { zeroTime = time } @@ -93,7 +93,7 @@ extends Logging with Serializable { } /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ - private def isTimeValid (time: Time): Boolean = { + protected def isTimeValid (time: Time): Boolean = { if (!isInitialized) { throw new Exception (this.toString + " has not been initialized") } else if (time < zeroTime || ! (time - zeroTime).isMultipleOf(slideTime)) { @@ -208,7 +208,7 @@ extends Logging with Serializable { new TransformedDStream(this, ssc.sc.clean(transformFunc)) } - private[streaming] def toQueue = { + def toQueue = { val queue = new ArrayBlockingQueue[RDD[T]](10000) this.foreachRDD(rdd => { queue.add(rdd) diff --git a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala index bab48ff954..de30297c7d 100644 --- a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala @@ -7,7 +7,7 @@ import scala.collection.mutable.Queue import scala.collection.mutable.ArrayBuffer class QueueInputDStream[T: ClassManifest]( - ssc: StreamingContext, + @transient ssc: StreamingContext, val queue: Queue[RDD[T]], oneAtATime: Boolean, defaultRDD: RDD[T] diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index f313d8c162..4cb780c006 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -1,10 +1,11 @@ package spark.streaming import spark.RDD +import spark.BlockRDD import spark.Partitioner import spark.MapPartitionsRDD import spark.SparkContext._ - +import spark.storage.StorageLevel class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( parent: DStream[(K, V)], @@ -22,6 +23,47 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife override def slideTime = parent.slideTime + override def getOrCompute(time: Time): Option[RDD[(K, S)]] = { + generatedRDDs.get(time) match { + case Some(oldRDD) => { + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval) && oldRDD.dependencies.size > 0) { + val r = oldRDD + val oldRDDBlockIds = oldRDD.splits.map(s => "rdd:" + r.id + ":" + s.index) + val checkpointedRDD = new BlockRDD[(K, S)](ssc.sc, oldRDDBlockIds) { + override val partitioner = oldRDD.partitioner + } + generatedRDDs.update(time, checkpointedRDD) + logInfo("Updated RDD of time " + time + " with its checkpointed version") + Some(checkpointedRDD) + } else { + Some(oldRDD) + } + } + case None => { + if (isTimeValid(time)) { + compute(time) match { + case Some(newRDD) => { + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.persist(checkpointLevel) + logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) + } else if (storageLevel != StorageLevel.NONE) { + newRDD.persist(storageLevel) + logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) + } + generatedRDDs.put(time, newRDD) + Some(newRDD) + } + case None => { + None + } + } + } else { + None + } + } + } + } + override def compute(validTime: Time): Option[RDD[(K, S)]] = { // Try to get the previous state RDD @@ -29,26 +71,27 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife case Some(prevStateRDD) => { // If previous state RDD exists - // Define the function for the mapPartition operation on cogrouped RDD; - // first map the cogrouped tuple to tuples of required type, - // and then apply the update function - val func = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { - val i = iterator.map(t => { - (t._1, t._2._1, t._2._2.headOption.getOrElse(null.asInstanceOf[S])) - }) - updateFunc(i) - } - // Try to get the parent RDD parent.getOrCompute(validTime) match { case Some(parentRDD) => { // If parent RDD exists, then compute as usual + + // Define the function for the mapPartition operation on cogrouped RDD; + // first map the cogrouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val mapPartitionFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { + val i = iterator.map(t => { + (t._1, t._2._1, t._2._2.headOption.getOrElse(null.asInstanceOf[S])) + }) + updateFuncLocal(i) + } val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) - val stateRDD = new SpecialMapPartitionsRDD(cogroupedRDD, func) - logDebug("Generating state RDD for time " + validTime) + val stateRDD = new SpecialMapPartitionsRDD(cogroupedRDD, mapPartitionFunc) + //logDebug("Generating state RDD for time " + validTime) return Some(stateRDD) } case None => { // If parent RDD does not exist, then return old state RDD - logDebug("Generating state RDD for time " + validTime + " (no change)") + //logDebug("Generating state RDD for time " + validTime + " (no change)") return Some(prevStateRDD) } } @@ -56,23 +99,25 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife case None => { // If previous session RDD does not exist (first input data) - // Define the function for the mapPartition operation on grouped RDD; - // first map the grouped tuple to tuples of required type, - // and then apply the update function - val func = (iterator: Iterator[(K, Seq[V])]) => { - updateFunc(iterator.map(tuple => (tuple._1, tuple._2, null.asInstanceOf[S]))) - } - // Try to get the parent RDD parent.getOrCompute(validTime) match { case Some(parentRDD) => { // If parent RDD exists, then compute as usual + + // Define the function for the mapPartition operation on grouped RDD; + // first map the grouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val mapPartitionFunc = (iterator: Iterator[(K, Seq[V])]) => { + updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, null.asInstanceOf[S]))) + } + val groupedRDD = parentRDD.groupByKey(partitioner) - val sessionRDD = new SpecialMapPartitionsRDD(groupedRDD, func) - logDebug("Generating state RDD for time " + validTime + " (first)") + val sessionRDD = new SpecialMapPartitionsRDD(groupedRDD, mapPartitionFunc) + //logDebug("Generating state RDD for time " + validTime + " (first)") return Some(sessionRDD) } case None => { // If parent RDD does not exist, then nothing to do! - logDebug("Not generating state RDD (no previous state, no parent)") + //logDebug("Not generating state RDD (no previous state, no parent)") return None } } diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala index 030f351080..fc00952afe 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala @@ -107,12 +107,12 @@ class DStreamSuite extends FunSuite with BeforeAndAfter with Logging { Seq(("a", 3), ("b", 3), ("c", 3)) ) - val updateStateOp =(s: DStream[String]) => { + val updateStateOp = (s: DStream[String]) => { val updateFunc = (values: Seq[Int], state: RichInt) => { var newState = 0 if (values != null) newState += values.reduce(_ + _) if (state != null) newState += state.self - //println("values = " + values + ", state = " + state + ", " + " new state = " + newState) + println("values = " + values + ", state = " + state + ", " + " new state = " + newState) new RichInt(newState) } s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self)) -- cgit v1.2.3 From 4ea032a142ab7fb44f92b145cc8d850164419ab5 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 05:53:07 +0000 Subject: Some changes to make important log output visible even if we set the logging to WARNING --- .../main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala | 2 +- streaming/src/main/scala/spark/streaming/JobManager.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 83e7c6e036..978b4f2676 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -99,7 +99,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor // Remove a disconnected slave from the cluster def removeSlave(slaveId: String) { - logInfo("Slave " + slaveId + " disconnected, so removing it") + logWarning("Slave " + slaveId + " disconnected, so removing it") val numCores = freeCores(slaveId) actorToSlaveId -= slaveActor(slaveId) addressToSlaveId -= slaveAddress(slaveId) diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 9bf9251519..230d806a89 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -12,7 +12,7 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { SparkEnv.set(ssc.env) try { val timeTaken = job.run() - logInfo("Total delay: %.5f s for job %s (execution: %.5f s)".format( + println("Total delay: %.5f s for job %s (execution: %.5f s)".format( (System.currentTimeMillis() - job.time) / 1000.0, job.id, timeTaken / 1000.0)) } catch { case e: Exception => -- cgit v1.2.3 From b7ad291ac52896af6cb1d882392f3d6fa0cf3b49 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 07:08:07 +0000 Subject: Tuning Akka for more connections --- core/src/main/scala/spark/util/AkkaUtils.scala | 1 + .../src/main/scala/spark/streaming/examples/WordCount2.scala | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index 57d212e4ca..fd64e224d7 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -31,6 +31,7 @@ object AkkaUtils { akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d akka.remote.netty.connection-timeout = 1s + akka.remote.netty.execution-pool-size = 10 """.format(host, port)) val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index aa542ba07d..8561e7f079 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -62,10 +62,10 @@ object WordCount2_ExtraFunctions { object WordCount2 { def warmup(sc: SparkContext) { - (0 until 10).foreach {i => - sc.parallelize(1 to 20000000, 1000) + (0 until 3).foreach {i => + sc.parallelize(1 to 20000000, 500) .map(x => (x % 337, x % 1331)) - .reduceByKey(_ + _) + .reduceByKey(_ + _, 100) .count() } } @@ -84,11 +84,11 @@ object WordCount2 { val ssc = new StreamingContext(master, "WordCount2") ssc.setBatchDuration(batchDuration) - //warmup(ssc.sc) + warmup(ssc.sc) val data = ssc.sc.textFile(file, mapTasks.toInt).persist( new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas - println("Data count: " + data.count()) + println("Data count: " + data.map(x => if (x == "") 1 else x.split(" ").size / x.split(" ").size).count()) println("Data count: " + data.count()) println("Data count: " + data.count()) -- cgit v1.2.3 From 75487b2f5a6abedd322520f759b814ec643aea01 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 08:14:50 +0000 Subject: Broadcast the JobConf in HadoopRDD to reduce task sizes --- core/src/main/scala/spark/HadoopRDD.scala | 5 +++-- core/src/main/scala/spark/KryoSerializer.scala | 4 ++++ 2 files changed, 7 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/HadoopRDD.scala b/core/src/main/scala/spark/HadoopRDD.scala index f282a4023b..0befca582d 100644 --- a/core/src/main/scala/spark/HadoopRDD.scala +++ b/core/src/main/scala/spark/HadoopRDD.scala @@ -42,7 +42,8 @@ class HadoopRDD[K, V]( minSplits: Int) extends RDD[(K, V)](sc) { - val serializableConf = new SerializableWritable(conf) + // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it + val confBroadcast = sc.broadcast(new SerializableWritable(conf)) @transient val splits_ : Array[Split] = { @@ -66,7 +67,7 @@ class HadoopRDD[K, V]( val split = theSplit.asInstanceOf[HadoopSplit] var reader: RecordReader[K, V] = null - val conf = serializableConf.value + val conf = confBroadcast.value.value val fmt = createInputFormat(conf) reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL) diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 65d0532bd5..3d042b2f11 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -10,6 +10,7 @@ import scala.collection.mutable import com.esotericsoftware.kryo._ import com.esotericsoftware.kryo.{Serializer => KSerializer} import com.esotericsoftware.kryo.serialize.ClassSerializer +import com.esotericsoftware.kryo.serialize.SerializableSerializer import de.javakaffee.kryoserializers.KryoReflectionFactorySupport import spark.storage._ @@ -203,6 +204,9 @@ class KryoSerializer extends Serializer with Logging { kryo.register(classOf[Class[_]], new ClassSerializer(kryo)) kryo.setRegistrationOptional(true) + // Allow sending SerializableWritable + kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer()) + // Register some commonly used Scala singleton objects. Because these // are singletons, we must return the exact same local object when we // deserialize rather than returning a clone as FieldSerializer would. -- cgit v1.2.3 From efc7668d16b2a58f8d074c1cdaeae4b37dae1c9c Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 08:22:57 +0000 Subject: Allow serializing HttpBroadcast through Kryo --- core/src/main/scala/spark/KryoSerializer.scala | 2 ++ 1 file changed, 2 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 3d042b2f11..8a3f565071 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -13,6 +13,7 @@ import com.esotericsoftware.kryo.serialize.ClassSerializer import com.esotericsoftware.kryo.serialize.SerializableSerializer import de.javakaffee.kryoserializers.KryoReflectionFactorySupport +import spark.broadcast._ import spark.storage._ /** @@ -206,6 +207,7 @@ class KryoSerializer extends Serializer with Logging { // Allow sending SerializableWritable kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer()) + kryo.register(classOf[HttpBroadcast[_]], new SerializableSerializer()) // Register some commonly used Scala singleton objects. Because these // are singletons, we must return the exact same local object when we -- cgit v1.2.3 From 3fa0d7f0c9883ab77e89b7bcf70b7b11df9a4184 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 08:28:15 +0000 Subject: Serialize BlockRDD more efficiently --- core/src/main/scala/spark/BlockRDD.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/BlockRDD.scala b/core/src/main/scala/spark/BlockRDD.scala index ea009f0f4f..daabc0d566 100644 --- a/core/src/main/scala/spark/BlockRDD.scala +++ b/core/src/main/scala/spark/BlockRDD.scala @@ -7,7 +7,8 @@ class BlockRDDSplit(val blockId: String, idx: Int) extends Split { } -class BlockRDD[T: ClassManifest](sc: SparkContext, blockIds: Array[String]) extends RDD[T](sc) { +class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) + extends RDD[T](sc) { @transient val splits_ = (0 until blockIds.size).map(i => { -- cgit v1.2.3 From 1d6b36d3c3698090b35d8e7c4f88cac410f9ea01 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 22:26:37 +0000 Subject: Further tuning for network performance --- core/src/main/scala/spark/storage/BlockMessage.scala | 14 +------------- core/src/main/scala/spark/util/AkkaUtils.scala | 3 ++- 2 files changed, 3 insertions(+), 14 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala index 0b2ed69e07..607633c6df 100644 --- a/core/src/main/scala/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/spark/storage/BlockMessage.scala @@ -12,7 +12,7 @@ case class GetBlock(id: String) case class GotBlock(id: String, data: ByteBuffer) case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel) -class BlockMessage() extends Logging{ +class BlockMessage() { // Un-initialized: typ = 0 // GetBlock: typ = 1 // GotBlock: typ = 2 @@ -22,8 +22,6 @@ class BlockMessage() extends Logging{ private var data: ByteBuffer = null private var level: StorageLevel = null - initLogging() - def set(getBlock: GetBlock) { typ = BlockMessage.TYPE_GET_BLOCK id = getBlock.id @@ -62,8 +60,6 @@ class BlockMessage() extends Logging{ } id = idBuilder.toString() - logDebug("Set from buffer Result: " + typ + " " + id) - logDebug("Buffer position is " + buffer.position) if (typ == BlockMessage.TYPE_PUT_BLOCK) { val booleanInt = buffer.getInt() @@ -77,23 +73,18 @@ class BlockMessage() extends Logging{ } data.put(buffer) data.flip() - logDebug("Set from buffer Result 2: " + level + " " + data) } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { val dataLength = buffer.getInt() - logDebug("Data length is "+ dataLength) - logDebug("Buffer position is " + buffer.position) data = ByteBuffer.allocate(dataLength) if (dataLength != buffer.remaining) { throw new Exception("Error parsing buffer") } data.put(buffer) data.flip() - logDebug("Set from buffer Result 3: " + data) } val finishTime = System.currentTimeMillis - logDebug("Converted " + id + " from bytebuffer in " + (finishTime - startTime) / 1000.0 + " s") } def set(bufferMsg: BufferMessage) { @@ -145,8 +136,6 @@ class BlockMessage() extends Logging{ buffers += data } - logDebug("Start to log buffers.") - buffers.foreach((x: ByteBuffer) => logDebug("" + x)) /* println() println("BlockMessage: ") @@ -160,7 +149,6 @@ class BlockMessage() extends Logging{ println() */ val finishTime = System.currentTimeMillis - logDebug("Converted " + id + " to buffer message in " + (finishTime - startTime) / 1000.0 + " s") return Message.createBufferMessage(buffers) } diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index fd64e224d7..330bb42e59 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -31,7 +31,8 @@ object AkkaUtils { akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d akka.remote.netty.connection-timeout = 1s - akka.remote.netty.execution-pool-size = 10 + akka.remote.netty.execution-pool-size = 4 + akka.actor.default-dispatcher.throughput = 20 """.format(host, port)) val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) -- cgit v1.2.3 From dc68febdce53efea43ee1ab91c05b14b7a5eae30 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 23:06:59 +0000 Subject: User Spark's closure serializer for the ShuffleMapTask cache --- core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 9 ++++----- .../main/scala/spark/scheduler/cluster/ClusterScheduler.scala | 1 + 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 73479bff01..f1eae9bc88 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -26,7 +26,8 @@ object ShuffleMapTask { return old } else { val out = new ByteArrayOutputStream - val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) + val ser = SparkEnv.get.closureSerializer.newInstance + val objOut = ser.serializeStream(new GZIPOutputStream(out)) objOut.writeObject(rdd) objOut.writeObject(dep) objOut.close() @@ -45,10 +46,8 @@ object ShuffleMapTask { } else { val loader = Thread.currentThread.getContextClassLoader val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, loader) - } + val ser = SparkEnv.get.closureSerializer.newInstance + val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]] val tuple = (rdd, dep) diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 5b59479682..20c82ad0fa 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -115,6 +115,7 @@ class ClusterScheduler(sc: SparkContext) */ def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = { synchronized { + SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname for (o <- offers) { slaveIdToHost(o.slaveId) = o.hostname -- cgit v1.2.3 From 215544820fe70274d9dce1410f61e2052b8bc406 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 23:54:04 +0000 Subject: Serialize map output locations more efficiently, and only once, in MapOutputTracker --- core/src/main/scala/spark/MapOutputTracker.scala | 88 +++++++++++++++++++++--- 1 file changed, 80 insertions(+), 8 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index de23eb6f48..cee2391c71 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -1,5 +1,6 @@ package spark +import java.io.{DataInputStream, DataOutputStream, ByteArrayOutputStream, ByteArrayInputStream} import java.util.concurrent.ConcurrentHashMap import akka.actor._ @@ -10,6 +11,7 @@ import akka.util.Duration import akka.util.Timeout import akka.util.duration._ +import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet import spark.storage.BlockManagerId @@ -18,12 +20,11 @@ sealed trait MapOutputTrackerMessage case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage case object StopMapOutputTracker extends MapOutputTrackerMessage -class MapOutputTrackerActor(bmAddresses: ConcurrentHashMap[Int, Array[BlockManagerId]]) -extends Actor with Logging { +class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging { def receive = { case GetMapOutputLocations(shuffleId: Int) => logInfo("Asked to get map output locations for shuffle " + shuffleId) - sender ! bmAddresses.get(shuffleId) + sender ! tracker.getSerializedLocations(shuffleId) case StopMapOutputTracker => logInfo("MapOutputTrackerActor stopped!") @@ -39,15 +40,19 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg val timeout = 10.seconds - private var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]] + var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. private var generation: Long = 0 private var generationLock = new java.lang.Object + // Cache a serialized version of the output locations for each shuffle to send them out faster + var cacheGeneration = generation + val cachedSerializedLocs = new HashMap[Int, Array[Byte]] + var trackerActor: ActorRef = if (isMaster) { - val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(bmAddresses)), name = actorName) + val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName) logInfo("Registered MapOutputTrackerActor actor") actor } else { @@ -134,15 +139,16 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg } // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) - val fetched = askTracker(GetMapOutputLocations(shuffleId)).asInstanceOf[Array[BlockManagerId]] + val fetchedBytes = askTracker(GetMapOutputLocations(shuffleId)).asInstanceOf[Array[Byte]] + val fetchedLocs = deserializeLocations(fetchedBytes) logInfo("Got the output locations") - bmAddresses.put(shuffleId, fetched) + bmAddresses.put(shuffleId, fetchedLocs) fetching.synchronized { fetching -= shuffleId fetching.notifyAll() } - return fetched + return fetchedLocs } else { return locs } @@ -181,4 +187,70 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg } } } + + def getSerializedLocations(shuffleId: Int): Array[Byte] = { + var locs: Array[BlockManagerId] = null + var generationGotten: Long = -1 + generationLock.synchronized { + if (generation > cacheGeneration) { + cachedSerializedLocs.clear() + cacheGeneration = generation + } + cachedSerializedLocs.get(shuffleId) match { + case Some(bytes) => + return bytes + case None => + locs = bmAddresses.get(shuffleId) + generationGotten = generation + } + } + // If we got here, we failed to find the serialized locations in the cache, so we pulled + // out a snapshot of the locations as "locs"; let's serialize and return that + val bytes = serializeLocations(locs) + // Add them into the table only if the generation hasn't changed while we were working + generationLock.synchronized { + if (generation == generationGotten) { + cachedSerializedLocs(shuffleId) = bytes + } + } + return bytes + } + + // Serialize an array of map output locations into an efficient byte format so that we can send + // it to reduce tasks. We do this by grouping together the locations by block manager ID. + def serializeLocations(locs: Array[BlockManagerId]): Array[Byte] = { + val out = new ByteArrayOutputStream + val dataOut = new DataOutputStream(out) + dataOut.writeInt(locs.length) + val grouped = locs.zipWithIndex.groupBy(_._1) + dataOut.writeInt(grouped.size) + for ((id, pairs) <- grouped) { + dataOut.writeUTF(id.ip) + dataOut.writeInt(id.port) + dataOut.writeInt(pairs.length) + for ((_, blockIndex) <- pairs) { + dataOut.writeInt(blockIndex) + } + } + dataOut.close() + out.toByteArray + } + + // Opposite of serializeLocations. + def deserializeLocations(bytes: Array[Byte]): Array[BlockManagerId] = { + val dataIn = new DataInputStream(new ByteArrayInputStream(bytes)) + val length = dataIn.readInt() + val array = new Array[BlockManagerId](length) + val numGroups = dataIn.readInt() + for (i <- 0 until numGroups) { + val ip = dataIn.readUTF() + val port = dataIn.readInt() + val id = new BlockManagerId(ip, port) + val numBlocks = dataIn.readInt() + for (j <- 0 until numBlocks) { + array(dataIn.readInt()) = id + } + } + array + } } -- cgit v1.2.3 From 2fa6d999fd92cb7ce828278edcd09eecd1f458c1 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 6 Sep 2012 00:16:39 +0000 Subject: Tuning Akka more --- core/src/main/scala/spark/util/AkkaUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index 330bb42e59..df4e23bfd6 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -31,8 +31,8 @@ object AkkaUtils { akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d akka.remote.netty.connection-timeout = 1s - akka.remote.netty.execution-pool-size = 4 - akka.actor.default-dispatcher.throughput = 20 + akka.remote.netty.execution-pool-size = 8 + akka.actor.default-dispatcher.throughput = 30 """.format(host, port)) val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) -- cgit v1.2.3 From 9ef90c95f4947e47f7c44f952ff8d294e0932a73 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 6 Sep 2012 00:43:46 +0000 Subject: Bug fix --- core/src/main/scala/spark/MapOutputTracker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index cee2391c71..82c1391345 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -224,7 +224,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg dataOut.writeInt(locs.length) val grouped = locs.zipWithIndex.groupBy(_._1) dataOut.writeInt(grouped.size) - for ((id, pairs) <- grouped) { + for ((id, pairs) <- grouped if id != null) { dataOut.writeUTF(id.ip) dataOut.writeInt(id.port) dataOut.writeInt(pairs.length) -- cgit v1.2.3 From db08a362aae68682f9105f9e5568bc9b9d9faaab Mon Sep 17 00:00:00 2001 From: haoyuan Date: Fri, 7 Sep 2012 02:17:52 +0000 Subject: commit opt for grep scalibility test. --- .../main/scala/spark/storage/BlockManager.scala | 7 +++- .../spark/streaming/NetworkInputTracker.scala | 40 ++++++++++++---------- .../scala/spark/streaming/RawInputDStream.scala | 17 +++++---- .../spark/streaming/examples/WordCountRaw.scala | 19 +++++++--- 4 files changed, 51 insertions(+), 32 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index f2d9499bad..4cdb9710ec 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -509,10 +509,15 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * Replicate block to another node. */ + var firstTime = true + var peers : Seq[BlockManagerId] = null private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { val tLevel: StorageLevel = new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) - var peers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + if (firstTime) { + peers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + firstTime = false; + } for (peer: BlockManagerId <- peers) { val start = System.nanoTime data.rewind() diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index acf97c1883..9f9001e4d5 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -4,6 +4,7 @@ import spark.Logging import spark.SparkEnv import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue import akka.actor._ import akka.pattern.ask @@ -28,6 +29,17 @@ extends Logging { logInfo("Registered receiver for network stream " + streamId) sender ! true } + case GotBlockIds(streamId, blockIds) => { + val tmp = receivedBlockIds.synchronized { + if (!receivedBlockIds.contains(streamId)) { + receivedBlockIds += ((streamId, new Queue[String])) + } + receivedBlockIds(streamId) + } + tmp.synchronized { + tmp ++= blockIds + } + } } } @@ -69,8 +81,8 @@ extends Logging { val networkInputStreamIds = networkInputStreams.map(_.id).toArray val receiverExecutor = new ReceiverExecutor() val receiverInfo = new HashMap[Int, ActorRef] - val receivedBlockIds = new HashMap[Int, Array[String]] - val timeout = 1000.milliseconds + val receivedBlockIds = new HashMap[Int, Queue[String]] + val timeout = 5000.milliseconds var currentTime: Time = null @@ -86,22 +98,12 @@ extends Logging { } def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized { - if (currentTime == null || time > currentTime) { - logInfo("Getting block ids from receivers for " + time) - implicit val ec = ssc.env.actorSystem.dispatcher - receivedBlockIds.clear() - val message = new GetBlockIds(time) - val listOfFutures = receiverInfo.values.map( - _.ask(message)(timeout).mapTo[GotBlockIds] - ).toList - val futureOfList = Future.sequence(listOfFutures) - val allBlockIds = Await.result(futureOfList, timeout) - receivedBlockIds ++= allBlockIds.map(x => (x.streamId, x.blocksIds)) - if (receivedBlockIds.size != receiverInfo.size) { - throw new Exception("Unexpected number of the Block IDs received") - } - currentTime = time + val queue = receivedBlockIds.synchronized { + receivedBlockIds.getOrElse(receiverId, new Queue[String]()) + } + val result = queue.synchronized { + queue.dequeueAll(x => true) } - receivedBlockIds.getOrElse(receiverId, Array[String]()) + result.toArray } -} \ No newline at end of file +} diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala index d59c245a23..d29aea7886 100644 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -86,14 +86,15 @@ class RawInputDStream[T: ClassManifest]( private class ReceiverActor(env: SparkEnv, receivingThread: Thread) extends Actor { val newBlocks = new ArrayBuffer[String] + logInfo("Attempting to register with tracker") + val ip = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val actorName: String = "NetworkInputTracker" + val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) + val trackerActor = env.actorSystem.actorFor(url) + val timeout = 5.seconds + override def preStart() { - logInfo("Attempting to register with tracker") - val ip = System.getProperty("spark.master.host", "localhost") - val port = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "NetworkInputTracker" - val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) - val trackerActor = env.actorSystem.actorFor(url) - val timeout = 1.seconds val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) Await.result(future, timeout) } @@ -101,6 +102,7 @@ class RawInputDStream[T: ClassManifest]( override def receive = { case BlockPublished(blockId) => newBlocks += blockId + val future = trackerActor ! GotBlockIds(streamId, Array(blockId)) case GetBlockIds(time) => logInfo("Got request for block IDs for " + time) @@ -111,5 +113,6 @@ class RawInputDStream[T: ClassManifest]( receivingThread.interrupt() sender ! true } + } } diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index 298d9ef381..9702003805 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -1,11 +1,24 @@ package spark.streaming.examples import spark.util.IntParam +import spark.SparkContext +import spark.SparkContext._ import spark.storage.StorageLevel import spark.streaming._ import spark.streaming.StreamingContext._ +import WordCount2_ExtraFunctions._ + object WordCountRaw { + def moreWarmup(sc: SparkContext) { + (0 until 40).foreach {i => + sc.parallelize(1 to 20000000, 1000) + .map(_ % 1331).map(_.toString) + .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + .collect() + } + } + def main(args: Array[String]) { if (args.length != 7) { System.err.println("Usage: WordCountRaw ") @@ -20,16 +33,12 @@ object WordCountRaw { ssc.setBatchDuration(Milliseconds(batchMs)) // Make sure some tasks have started on each node - ssc.sc.parallelize(1 to 1000, 1000).count() - ssc.sc.parallelize(1 to 1000, 1000).count() - ssc.sc.parallelize(1 to 1000, 1000).count() + moreWarmup(ssc.sc) val rawStreams = (1 to streams).map(_ => ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray val union = new UnifiedDStream(rawStreams) - import WordCount2_ExtraFunctions._ - val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, -- cgit v1.2.3 From c63a6064584ea19d62e0abcbd3886d7b1e429ea1 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 10 Sep 2012 19:51:27 +0000 Subject: Made NewHadoopRDD broadcast its job configuration (same as HadoopRDD). --- core/src/main/scala/spark/NewHadoopRDD.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/NewHadoopRDD.scala b/core/src/main/scala/spark/NewHadoopRDD.scala index d024d38aa9..14f708a3f8 100644 --- a/core/src/main/scala/spark/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/NewHadoopRDD.scala @@ -28,7 +28,9 @@ class NewHadoopRDD[K, V]( @transient conf: Configuration) extends RDD[(K, V)](sc) { - private val serializableConf = new SerializableWritable(conf) + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it + val confBroadcast = sc.broadcast(new SerializableWritable(conf)) + // private val serializableConf = new SerializableWritable(conf) private val jobtrackerId: String = { val formatter = new SimpleDateFormat("yyyyMMddHHmm") @@ -41,7 +43,7 @@ class NewHadoopRDD[K, V]( @transient private val splits_ : Array[Split] = { val inputFormat = inputFormatClass.newInstance - val jobContext = new JobContext(serializableConf.value, jobId) + val jobContext = new JobContext(conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Split](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -54,9 +56,9 @@ class NewHadoopRDD[K, V]( override def compute(theSplit: Split) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopSplit] - val conf = serializableConf.value + val conf = confBroadcast.value.value val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0) - val context = new TaskAttemptContext(serializableConf.value, attemptId) + val context = new TaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance val reader = format.createRecordReader(split.serializableHadoopSplit.value, context) reader.initialize(split.serializableHadoopSplit.value, context) -- cgit v1.2.3 From e95ff45b53bf995d89f1825b9581cc18a083a438 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 13 Oct 2012 20:10:49 -0700 Subject: Implemented checkpointing of StreamingContext and DStream graph. --- core/src/main/scala/spark/SparkContext.scala | 4 +- .../main/scala/spark/streaming/Checkpoint.scala | 92 +++++++++++++++ .../src/main/scala/spark/streaming/DStream.scala | 123 ++++++++++++++------- .../main/scala/spark/streaming/DStreamGraph.scala | 80 ++++++++++++++ .../scala/spark/streaming/FileInputDStream.scala | 59 ++++++---- .../spark/streaming/ReducedWindowedDStream.scala | 80 +++++++------- .../src/main/scala/spark/streaming/Scheduler.scala | 33 +++--- .../main/scala/spark/streaming/StateDStream.scala | 20 ++-- .../scala/spark/streaming/StreamingContext.scala | 109 ++++++++++++------ .../examples/FileStreamWithCheckpoint.scala | 76 +++++++++++++ .../scala/spark/streaming/examples/Grep2.scala | 2 +- .../spark/streaming/examples/WordCount2.scala | 2 +- .../scala/spark/streaming/examples/WordMax2.scala | 2 +- .../spark/streaming/util/RecurringTimer.scala | 19 +++- 14 files changed, 536 insertions(+), 165 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/Checkpoint.scala create mode 100644 streaming/src/main/scala/spark/streaming/DStreamGraph.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bebebe8262..1d5131ad13 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -46,8 +46,8 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend import spark.storage.BlockManagerMaster class SparkContext( - master: String, - frameworkName: String, + val master: String, + val frameworkName: String, val sparkHome: String, val jars: Seq[String]) extends Logging { diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala new file mode 100644 index 0000000000..3bd8fd5a27 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -0,0 +1,92 @@ +package spark.streaming + +import spark.Utils + +import org.apache.hadoop.fs.{FileUtil, Path} +import org.apache.hadoop.conf.Configuration + +import java.io.{ObjectInputStream, ObjectOutputStream} + +class Checkpoint(@transient ssc: StreamingContext) extends Serializable { + val master = ssc.sc.master + val frameworkName = ssc.sc.frameworkName + val sparkHome = ssc.sc.sparkHome + val jars = ssc.sc.jars + val graph = ssc.graph + val batchDuration = ssc.batchDuration + val checkpointFile = ssc.checkpointFile + val checkpointInterval = ssc.checkpointInterval + + def saveToFile(file: String) { + val path = new Path(file) + val conf = new Configuration() + val fs = path.getFileSystem(conf) + if (fs.exists(path)) { + val bkPath = new Path(path.getParent, path.getName + ".bk") + FileUtil.copy(fs, path, fs, bkPath, true, true, conf) + println("Moved existing checkpoint file to " + bkPath) + } + val fos = fs.create(path) + val oos = new ObjectOutputStream(fos) + oos.writeObject(this) + oos.close() + fs.close() + } + + def toBytes(): Array[Byte] = { + val cp = new Checkpoint(ssc) + val bytes = Utils.serialize(cp) + bytes + } +} + +object Checkpoint { + + def loadFromFile(file: String): Checkpoint = { + val path = new Path(file) + val conf = new Configuration() + val fs = path.getFileSystem(conf) + if (!fs.exists(path)) { + throw new Exception("Could not read checkpoint file " + path) + } + val fis = fs.open(path) + val ois = new ObjectInputStream(fis) + val cp = ois.readObject.asInstanceOf[Checkpoint] + ois.close() + fs.close() + cp + } + + def fromBytes(bytes: Array[Byte]): Checkpoint = { + Utils.deserialize[Checkpoint](bytes) + } + + /*def toBytes(ssc: StreamingContext): Array[Byte] = { + val cp = new Checkpoint(ssc) + val bytes = Utils.serialize(cp) + bytes + } + + + def saveContext(ssc: StreamingContext, file: String) { + val cp = new Checkpoint(ssc) + val path = new Path(file) + val conf = new Configuration() + val fs = path.getFileSystem(conf) + if (fs.exists(path)) { + val bkPath = new Path(path.getParent, path.getName + ".bk") + FileUtil.copy(fs, path, fs, bkPath, true, true, conf) + println("Moved existing checkpoint file to " + bkPath) + } + val fos = fs.create(path) + val oos = new ObjectOutputStream(fos) + oos.writeObject(cp) + oos.close() + fs.close() + } + + def loadContext(file: String): StreamingContext = { + loadCheckpoint(file).createNewContext() + } + */ +} diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 7e8098c346..78e4c57647 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -2,20 +2,19 @@ package spark.streaming import spark.streaming.StreamingContext._ -import spark.RDD -import spark.UnionRDD -import spark.Logging +import spark._ import spark.SparkContext._ import spark.storage.StorageLevel -import spark.Partitioner import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import java.util.concurrent.ArrayBlockingQueue +import java.io.{ObjectInputStream, IOException, ObjectOutputStream} +import scala.Some -abstract class DStream[T: ClassManifest] (@transient val ssc: StreamingContext) -extends Logging with Serializable { +abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) +extends Serializable with Logging { initLogging() @@ -41,10 +40,10 @@ extends Logging with Serializable { */ // Variable to store the RDDs generated earlier in time - @transient protected val generatedRDDs = new HashMap[Time, RDD[T]] () + protected val generatedRDDs = new HashMap[Time, RDD[T]] () // Variable to be set to the first time seen by the DStream (effective time zero) - protected[streaming] var zeroTime: Time = null + protected var zeroTime: Time = null // Variable to specify storage level protected var storageLevel: StorageLevel = StorageLevel.NONE @@ -53,6 +52,9 @@ extends Logging with Serializable { protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint protected var checkpointInterval: Time = null + // Reference to whole DStream graph, so that checkpointing process can lock it + protected var graph: DStreamGraph = null + // Change this RDD's storage level def persist( storageLevel: StorageLevel, @@ -77,7 +79,7 @@ extends Logging with Serializable { // Turn on the default caching level for this RDD def cache(): DStream[T] = persist() - def isInitialized = (zeroTime != null) + def isInitialized() = (zeroTime != null) /** * This method initializes the DStream by setting the "zero" time, based on which @@ -85,15 +87,33 @@ extends Logging with Serializable { * its parent DStreams. */ protected[streaming] def initialize(time: Time) { - if (zeroTime == null) { - zeroTime = time + if (zeroTime != null) { + throw new Exception("ZeroTime is already initialized, cannot initialize it again") } + zeroTime = time logInfo(this + " initialized") dependencies.foreach(_.initialize(zeroTime)) } + protected[streaming] def setContext(s: StreamingContext) { + if (ssc != null && ssc != s) { + throw new Exception("Context is already set, cannot set it again") + } + ssc = s + logInfo("Set context for " + this.getClass.getSimpleName) + dependencies.foreach(_.setContext(ssc)) + } + + protected[streaming] def setGraph(g: DStreamGraph) { + if (graph != null && graph != g) { + throw new Exception("Graph is already set, cannot set it again") + } + graph = g + dependencies.foreach(_.setGraph(graph)) + } + /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ - protected def isTimeValid (time: Time): Boolean = { + protected def isTimeValid(time: Time): Boolean = { if (!isInitialized) { throw new Exception (this.toString + " has not been initialized") } else if (time < zeroTime || ! (time - zeroTime).isMultipleOf(slideTime)) { @@ -158,13 +178,42 @@ extends Logging with Serializable { } } + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + println(this.getClass().getSimpleName + ".writeObject used") + if (graph != null) { + graph.synchronized { + if (graph.checkpointInProgress) { + oos.defaultWriteObject() + } else { + val msg = "Object of " + this.getClass.getName + " is being serialized " + + " possibly as a part of closure of an RDD operation. This is because " + + " the DStream object is being referred to from within the closure. " + + " Please rewrite the RDD operation inside this DStream to avoid this. " + + " This has been enforced to avoid bloating of Spark tasks " + + " with unnecessary objects." + throw new java.io.NotSerializableException(msg) + } + } + } else { + throw new java.io.NotSerializableException("Graph is unexpectedly null when DStream is being serialized.") + } + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + println(this.getClass().getSimpleName + ".readObject used") + ois.defaultReadObject() + } + /** * -------------- * DStream operations * -------------- */ - - def map[U: ClassManifest](mapFunc: T => U) = new MappedDStream(this, ssc.sc.clean(mapFunc)) + def map[U: ClassManifest](mapFunc: T => U) = { + new MappedDStream(this, ssc.sc.clean(mapFunc)) + } def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]) = { new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc)) @@ -262,19 +311,15 @@ extends Logging with Serializable { // Get all the RDDs between fromTime to toTime (both included) def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { - val rdds = new ArrayBuffer[RDD[T]]() var time = toTime.floor(slideTime) - - while (time >= zeroTime && time >= fromTime) { getOrCompute(time) match { case Some(rdd) => rdds += rdd - case None => throw new Exception("Could not get old reduced RDD for time " + time) + case None => //throw new Exception("Could not get RDD for time " + time) } time -= slideTime } - rdds.toSeq } @@ -284,12 +329,16 @@ extends Logging with Serializable { } -abstract class InputDStream[T: ClassManifest] (@transient ssc: StreamingContext) - extends DStream[T](ssc) { +abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) + extends DStream[T](ssc_) { override def dependencies = List() - override def slideTime = ssc.batchDuration + override def slideTime = { + if (ssc == null) throw new Exception("ssc is null") + if (ssc.batchDuration == null) throw new Exception("ssc.batchDuration is null") + ssc.batchDuration + } def start() @@ -302,7 +351,7 @@ abstract class InputDStream[T: ClassManifest] (@transient ssc: StreamingContext) */ class MappedDStream[T: ClassManifest, U: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], mapFunc: T => U ) extends DStream[U](parent.ssc) { @@ -321,7 +370,7 @@ class MappedDStream[T: ClassManifest, U: ClassManifest] ( */ class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( - @transient parent: DStream[T], + parent: DStream[T], flatMapFunc: T => Traversable[U] ) extends DStream[U](parent.ssc) { @@ -340,7 +389,7 @@ class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( */ class FilteredDStream[T: ClassManifest]( - @transient parent: DStream[T], + parent: DStream[T], filterFunc: T => Boolean ) extends DStream[T](parent.ssc) { @@ -359,7 +408,7 @@ class FilteredDStream[T: ClassManifest]( */ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( - @transient parent: DStream[T], + parent: DStream[T], mapPartFunc: Iterator[T] => Iterator[U] ) extends DStream[U](parent.ssc) { @@ -377,7 +426,7 @@ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( * TODO */ -class GlommedDStream[T: ClassManifest](@transient parent: DStream[T]) +class GlommedDStream[T: ClassManifest](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { override def dependencies = List(parent) @@ -395,7 +444,7 @@ class GlommedDStream[T: ClassManifest](@transient parent: DStream[T]) */ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( - @transient parent: DStream[(K,V)], + parent: DStream[(K,V)], createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiner: (C, C) => C, @@ -420,7 +469,7 @@ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( * TODO */ -class UnifiedDStream[T: ClassManifest](@transient parents: Array[DStream[T]]) +class UnifiedDStream[T: ClassManifest](parents: Array[DStream[T]]) extends DStream[T](parents(0).ssc) { if (parents.length == 0) { @@ -459,7 +508,7 @@ class UnifiedDStream[T: ClassManifest](@transient parents: Array[DStream[T]]) */ class PerElementForEachDStream[T: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], foreachFunc: T => Unit ) extends DStream[Unit](parent.ssc) { @@ -490,7 +539,7 @@ class PerElementForEachDStream[T: ClassManifest] ( */ class PerRDDForEachDStream[T: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], foreachFunc: (RDD[T], Time) => Unit ) extends DStream[Unit](parent.ssc) { @@ -518,15 +567,15 @@ class PerRDDForEachDStream[T: ClassManifest] ( */ class TransformedDStream[T: ClassManifest, U: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], transformFunc: (RDD[T], Time) => RDD[U] ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies = List(parent) - override def slideTime: Time = parent.slideTime + override def slideTime: Time = parent.slideTime - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(transformFunc(_, validTime)) - } + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(transformFunc(_, validTime)) } +} diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala new file mode 100644 index 0000000000..67859e0131 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -0,0 +1,80 @@ +package spark.streaming + +import java.io.{ObjectInputStream, IOException, ObjectOutputStream} +import collection.mutable.ArrayBuffer + +final class DStreamGraph extends Serializable { + + private val inputStreams = new ArrayBuffer[InputDStream[_]]() + private val outputStreams = new ArrayBuffer[DStream[_]]() + + private[streaming] var zeroTime: Time = null + private[streaming] var checkpointInProgress = false; + + def started() = (zeroTime != null) + + def start(time: Time) { + this.synchronized { + if (started) { + throw new Exception("DStream graph computation already started") + } + zeroTime = time + outputStreams.foreach(_.initialize(zeroTime)) + inputStreams.par.foreach(_.start()) + } + + } + + def stop() { + this.synchronized { + inputStreams.par.foreach(_.stop()) + } + } + + private[streaming] def setContext(ssc: StreamingContext) { + this.synchronized { + outputStreams.foreach(_.setContext(ssc)) + } + } + + def addInputStream(inputStream: InputDStream[_]) { + inputStream.setGraph(this) + inputStreams += inputStream + } + + def addOutputStream(outputStream: DStream[_]) { + outputStream.setGraph(this) + outputStreams += outputStream + } + + def getInputStreams() = inputStreams.toArray + + def getOutputStreams() = outputStreams.toArray + + def generateRDDs(time: Time): Seq[Job] = { + this.synchronized { + outputStreams.flatMap(outputStream => outputStream.generateJob(time)) + } + } + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + this.synchronized { + checkpointInProgress = true + oos.defaultWriteObject() + checkpointInProgress = false + } + println("DStreamGraph.writeObject used") + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + this.synchronized { + checkpointInProgress = true + ois.defaultReadObject() + checkpointInProgress = false + } + println("DStreamGraph.readObject used") + } +} + diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala index 96a64f0018..29ae89616e 100644 --- a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala @@ -1,33 +1,45 @@ package spark.streaming -import spark.SparkContext import spark.RDD -import spark.BlockRDD import spark.UnionRDD -import spark.storage.StorageLevel -import spark.streaming._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import java.net.InetSocketAddress - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.PathFilter +import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import java.io.{ObjectInputStream, IOException} class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( - ssc: StreamingContext, - directory: Path, + @transient ssc_ : StreamingContext, + directory: String, filter: PathFilter = FileInputDStream.defaultPathFilter, newFilesOnly: Boolean = true) - extends InputDStream[(K, V)](ssc) { - - val fs = directory.getFileSystem(new Configuration()) + extends InputDStream[(K, V)](ssc_) { + + @transient private var path_ : Path = null + @transient private var fs_ : FileSystem = null + + /* + @transient @noinline lazy val path = { + //if (directory == null) throw new Exception("directory is null") + //println(directory) + new Path(directory) + } + @transient lazy val fs = path.getFileSystem(new Configuration()) + */ + var lastModTime: Long = 0 - + + def path(): Path = { + if (path_ == null) path_ = new Path(directory) + path_ + } + + def fs(): FileSystem = { + if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) + fs_ + } + override def start() { if (newFilesOnly) { lastModTime = System.currentTimeMillis() @@ -58,7 +70,7 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K } } - val newFiles = fs.listStatus(directory, newFilter) + val newFiles = fs.listStatus(path, newFilter) logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) if (newFiles.length > 0) { lastModTime = newFilter.latestModTime @@ -67,10 +79,19 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) Some(newRDD) } + /* + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + println(this.getClass().getSimpleName + ".readObject used") + ois.defaultReadObject() + println("HERE HERE" + this.directory) + } + */ + } object FileInputDStream { - val defaultPathFilter = new PathFilter { + val defaultPathFilter = new PathFilter with Serializable { def accept(path: Path): Boolean = { val file = path.getName() if (file.startsWith(".") || file.endsWith("_tmp")) { diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index b0beaba94d..e161b5ba92 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -10,9 +10,10 @@ import spark.SparkContext._ import spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer +import collection.SeqProxy class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( - @transient parent: DStream[(K, V)], + parent: DStream[(K, V)], reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, _windowTime: Time, @@ -46,6 +47,8 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( } override def compute(validTime: Time): Option[RDD[(K, V)]] = { + val reduceF = reduceFunc + val invReduceF = invReduceFunc val currentTime = validTime val currentWindow = Interval(currentTime - windowTime + parent.slideTime, currentTime) @@ -84,54 +87,47 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( // Cogroup the reduced RDDs and merge the reduced values val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner) - val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ - val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValuesFunc) + //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ - Some(mergedValuesRDD) - } - - def mergeValues(numOldValues: Int, numNewValues: Int)(seqOfValues: Seq[Seq[V]]): V = { - - if (seqOfValues.size != 1 + numOldValues + numNewValues) { - throw new Exception("Unexpected number of sequences of reduced values") - } - - // Getting reduced values "old time steps" that will be removed from current window - val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) - - // Getting reduced values "new time steps" - val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) - - if (seqOfValues(0).isEmpty) { + val numOldValues = oldRDDs.size + val numNewValues = newRDDs.size - // If previous window's reduce value does not exist, then at least new values should exist - if (newValues.isEmpty) { - throw new Exception("Neither previous window has value for key, nor new values found") + val mergeValues = (seqOfValues: Seq[Seq[V]]) => { + if (seqOfValues.size != 1 + numOldValues + numNewValues) { + throw new Exception("Unexpected number of sequences of reduced values") } + // Getting reduced values "old time steps" that will be removed from current window + val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) + // Getting reduced values "new time steps" + val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) + if (seqOfValues(0).isEmpty) { + // If previous window's reduce value does not exist, then at least new values should exist + if (newValues.isEmpty) { + throw new Exception("Neither previous window has value for key, nor new values found") + } + // Reduce the new values + newValues.reduce(reduceF) // return + } else { + // Get the previous window's reduced value + var tempValue = seqOfValues(0).head + // If old values exists, then inverse reduce then from previous value + if (!oldValues.isEmpty) { + tempValue = invReduceF(tempValue, oldValues.reduce(reduceF)) + } + // If new values exists, then reduce them with previous value + if (!newValues.isEmpty) { + tempValue = reduceF(tempValue, newValues.reduce(reduceF)) + } + tempValue // return + } + } - // Reduce the new values - // println("new values = " + newValues.map(_.toString).reduce(_ + " " + _)) - return newValues.reduce(reduceFunc) - } else { + val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) - // Get the previous window's reduced value - var tempValue = seqOfValues(0).head + Some(mergedValuesRDD) + } - // If old values exists, then inverse reduce then from previous value - if (!oldValues.isEmpty) { - // println("old values = " + oldValues.map(_.toString).reduce(_ + " " + _)) - tempValue = invReduceFunc(tempValue, oldValues.reduce(reduceFunc)) - } - // If new values exists, then reduce them with previous value - if (!newValues.isEmpty) { - // println("new values = " + newValues.map(_.toString).reduce(_ + " " + _)) - tempValue = reduceFunc(tempValue, newValues.reduce(reduceFunc)) - } - // println("final value = " + tempValue) - return tempValue - } - } } diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index d2e907378d..d62b7e7140 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -11,45 +11,44 @@ import scala.collection.mutable.HashMap sealed trait SchedulerMessage case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage -class Scheduler( - ssc: StreamingContext, - inputStreams: Array[InputDStream[_]], - outputStreams: Array[DStream[_]]) +class Scheduler(ssc: StreamingContext) extends Logging { initLogging() + val graph = ssc.graph val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt val jobManager = new JobManager(ssc, concurrentJobs) val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock") val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] val timer = new RecurringTimer(clock, ssc.batchDuration, generateRDDs(_)) - + + def start() { - val zeroTime = Time(timer.start()) - outputStreams.foreach(_.initialize(zeroTime)) - inputStreams.par.foreach(_.start()) + if (graph.started) { + timer.restart(graph.zeroTime.milliseconds) + } else { + val zeroTime = Time(timer.start()) + graph.start(zeroTime) + } logInfo("Scheduler started") } def stop() { timer.stop() - inputStreams.par.foreach(_.stop()) + graph.stop() logInfo("Scheduler stopped") } - def generateRDDs (time: Time) { + def generateRDDs(time: Time) { SparkEnv.set(ssc.env) logInfo("\n-----------------------------------------------------\n") logInfo("Generating RDDs for time " + time) - outputStreams.foreach(outputStream => { - outputStream.generateJob(time) match { - case Some(job) => submitJob(job) - case None => - } - } - ) + graph.generateRDDs(time).foreach(submitJob) logInfo("Generated RDDs for time " + time) + if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { + ssc.checkpoint() + } } def submitJob(job: Job) { diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index c40f70c91d..d223f25dfc 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -7,6 +7,12 @@ import spark.MapPartitionsRDD import spark.SparkContext._ import spark.storage.StorageLevel + +class StateRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: Iterator[T] => Iterator[U], rememberPartitioner: Boolean) + extends MapPartitionsRDD[U, T](prev, f) { + override val partitioner = if (rememberPartitioner) prev.partitioner else None +} + class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( @transient parent: DStream[(K, V)], updateFunc: (Iterator[(K, Seq[V], S)]) => Iterator[(K, S)], @@ -14,11 +20,6 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife rememberPartitioner: Boolean ) extends DStream[(K, S)](parent.ssc) { - class SpecialMapPartitionsRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: Iterator[T] => Iterator[U]) - extends MapPartitionsRDD(prev, f) { - override val partitioner = if (rememberPartitioner) prev.partitioner else None - } - override def dependencies = List(parent) override def slideTime = parent.slideTime @@ -79,19 +80,18 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife // first map the cogrouped tuple to tuples of required type, // and then apply the update function val updateFuncLocal = updateFunc - val mapPartitionFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { + val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { val i = iterator.map(t => { (t._1, t._2._1, t._2._2.headOption.getOrElse(null.asInstanceOf[S])) }) updateFuncLocal(i) } val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) - val stateRDD = new SpecialMapPartitionsRDD(cogroupedRDD, mapPartitionFunc) + val stateRDD = new StateRDD(cogroupedRDD, finalFunc, rememberPartitioner) //logDebug("Generating state RDD for time " + validTime) return Some(stateRDD) } case None => { // If parent RDD does not exist, then return old state RDD - //logDebug("Generating state RDD for time " + validTime + " (no change)") return Some(prevStateRDD) } } @@ -107,12 +107,12 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife // first map the grouped tuple to tuples of required type, // and then apply the update function val updateFuncLocal = updateFunc - val mapPartitionFunc = (iterator: Iterator[(K, Seq[V])]) => { + val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, null.asInstanceOf[S]))) } val groupedRDD = parentRDD.groupByKey(partitioner) - val sessionRDD = new SpecialMapPartitionsRDD(groupedRDD, mapPartitionFunc) + val sessionRDD = new StateRDD(groupedRDD, finalFunc, rememberPartitioner) //logDebug("Generating state RDD for time " + validTime + " (first)") return Some(sessionRDD) } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 12f3626680..1499ef4ea2 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -21,31 +21,70 @@ import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat -class StreamingContext (@transient val sc: SparkContext) extends Logging { +class StreamingContext ( + sc_ : SparkContext, + cp_ : Checkpoint + ) extends Logging { + + def this(sparkContext: SparkContext) = this(sparkContext, null) def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) = - this(new SparkContext(master, frameworkName, sparkHome, jars)) + this(new SparkContext(master, frameworkName, sparkHome, jars), null) + + def this(file: String) = this(null, Checkpoint.loadFromFile(file)) + + def this(cp_ : Checkpoint) = this(null, cp_) initLogging() + if (sc_ == null && cp_ == null) { + throw new Exception("Streaming Context cannot be initilalized with " + + "both SparkContext and checkpoint as null") + } + + val isCheckpointPresent = (cp_ != null) + + val sc: SparkContext = { + if (isCheckpointPresent) { + new SparkContext(cp_.master, cp_.frameworkName, cp_.sparkHome, cp_.jars) + } else { + sc_ + } + } + val env = SparkEnv.get - - val inputStreams = new ArrayBuffer[InputDStream[_]]() - val outputStreams = new ArrayBuffer[DStream[_]]() + + val graph: DStreamGraph = { + if (isCheckpointPresent) { + + cp_.graph.setContext(this) + cp_.graph + } else { + new DStreamGraph() + } + } + val nextNetworkInputStreamId = new AtomicInteger(0) - var batchDuration: Time = null - var scheduler: Scheduler = null + var batchDuration: Time = if (isCheckpointPresent) cp_.batchDuration else null + var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null + var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null var networkInputTracker: NetworkInputTracker = null - var receiverJobThread: Thread = null - - def setBatchDuration(duration: Long) { - setBatchDuration(Time(duration)) - } - + var receiverJobThread: Thread = null + var scheduler: Scheduler = null + def setBatchDuration(duration: Time) { + if (batchDuration != null) { + throw new Exception("Batch duration alread set as " + batchDuration + + ". cannot set it again.") + } batchDuration = duration } + + def setCheckpointDetails(file: String, interval: Time) { + checkpointFile = file + checkpointInterval = interval + } private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() @@ -59,7 +98,7 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { converter: (InputStream) => Iterator[T] ): DStream[T] = { val inputStream = new ObjectInputDStream[T](this, hostname, port, converter) - inputStreams += inputStream + graph.addInputStream(inputStream) inputStream } @@ -69,7 +108,7 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_2 ): DStream[T] = { val inputStream = new RawInputDStream[T](this, hostname, port, storageLevel) - inputStreams += inputStream + graph.addInputStream(inputStream) inputStream } @@ -94,8 +133,8 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { V: ClassManifest, F <: NewInputFormat[K, V]: ClassManifest ](directory: String): DStream[(K, V)] = { - val inputStream = new FileInputDStream[K, V, F](this, new Path(directory)) - inputStreams += inputStream + val inputStream = new FileInputDStream[K, V, F](this, directory) + graph.addInputStream(inputStream) inputStream } @@ -113,24 +152,31 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { defaultRDD: RDD[T] = null ): DStream[T] = { val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD) - inputStreams += inputStream + graph.addInputStream(inputStream) inputStream } - def createQueueStream[T: ClassManifest](iterator: Iterator[RDD[T]]): DStream[T] = { + def createQueueStream[T: ClassManifest](iterator: Array[RDD[T]]): DStream[T] = { val queue = new Queue[RDD[T]] val inputStream = createQueueStream(queue, true, null) queue ++= iterator inputStream - } + } + + /** + * This function registers a InputDStream as an input stream that will be + * started (InputDStream.start() called) to get the input data streams. + */ + def registerInputStream(inputStream: InputDStream[_]) { + graph.addInputStream(inputStream) + } - /** * This function registers a DStream as an output stream that will be * computed every interval. */ - def registerOutputStream (outputStream: DStream[_]) { - outputStreams += outputStream + def registerOutputStream(outputStream: DStream[_]) { + graph.addOutputStream(outputStream) } /** @@ -143,13 +189,9 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { if (batchDuration < Milliseconds(100)) { logWarning("Batch duration of " + batchDuration + " is very low") } - if (inputStreams.size == 0) { - throw new Exception("No input streams created, so nothing to take input from") - } - if (outputStreams.size == 0) { + if (graph.getOutputStreams().size == 0) { throw new Exception("No output streams registered, so nothing to execute") } - } /** @@ -157,7 +199,7 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { */ def start() { verify() - val networkInputStreams = inputStreams.filter(s => s match { + val networkInputStreams = graph.getInputStreams().filter(s => s match { case n: NetworkInputDStream[_] => true case _ => false }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray @@ -169,8 +211,9 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { } Thread.sleep(1000) - // Start the scheduler - scheduler = new Scheduler(this, inputStreams.toArray, outputStreams.toArray) + + // Start the scheduler + scheduler = new Scheduler(this) scheduler.start() } @@ -189,6 +232,10 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { logInfo("StreamingContext stopped") } + + def checkpoint() { + new Checkpoint(this).saveToFile(checkpointFile) + } } diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala new file mode 100644 index 0000000000..c725035a8a --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala @@ -0,0 +1,76 @@ +package spark.streaming.examples + +import spark.streaming._ +import spark.streaming.StreamingContext._ +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + +object FileStreamWithCheckpoint { + + def main(args: Array[String]) { + + if (args.size != 3) { + println("FileStreamWithCheckpoint ") + println("FileStreamWithCheckpoint restart ") + System.exit(-1) + } + + val directory = new Path(args(1)) + val checkpointFile = args(2) + + val ssc: StreamingContext = { + + if (args(0) == "restart") { + + // Recreated streaming context from specified checkpoint file + new StreamingContext(checkpointFile) + + } else { + + // Create directory if it does not exist + val fs = directory.getFileSystem(new Configuration()) + if (!fs.exists(directory)) fs.mkdirs(directory) + + // Create new streaming context + val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint") + ssc_.setBatchDuration(Seconds(1)) + ssc_.setCheckpointDetails(checkpointFile, Seconds(1)) + + // Setup the streaming computation + val inputStream = ssc_.createTextFileStream(directory.toString) + val words = inputStream.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + + ssc_ + } + } + + // Start the stream computation + startFileWritingThread(directory.toString) + ssc.start() + } + + def startFileWritingThread(directory: String) { + + val fs = new Path(directory).getFileSystem(new Configuration()) + + val fileWritingThread = new Thread() { + override def run() { + val r = new scala.util.Random() + val text = "This is a sample text file with a random number " + while(true) { + val number = r.nextInt() + val file = new Path(directory, number.toString) + val fos = fs.create(file) + fos.writeChars(text + number) + fos.close() + println("Created text file " + file) + Thread.sleep(1000) + } + } + } + fileWritingThread.start() + } + +} diff --git a/streaming/src/main/scala/spark/streaming/examples/Grep2.scala b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala index 7237142c7c..b1faa65c17 100644 --- a/streaming/src/main/scala/spark/streaming/examples/Grep2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala @@ -50,7 +50,7 @@ object Grep2 { println("Data count: " + data.count()) val sentences = new ConstantInputDStream(ssc, data) - ssc.inputStreams += sentences + ssc.registerInputStream(sentences) sentences.filter(_.contains("Culpepper")).count().foreachRDD(r => println("Grep count: " + r.collect().mkString)) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index c22949d7b9..8390f4af94 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -93,7 +93,7 @@ object WordCount2 { println("Data count: " + data.count()) val sentences = new ConstantInputDStream(ssc, data) - ssc.inputStreams += sentences + ssc.registerInputStream(sentences) import WordCount2_ExtraFunctions._ diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala index 3658cb302d..fc7567322b 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala @@ -50,7 +50,7 @@ object WordMax2 { println("Data count: " + data.count()) val sentences = new ConstantInputDStream(ssc, data) - ssc.inputStreams += sentences + ssc.registerInputStream(sentences) import WordCount2_ExtraFunctions._ diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala index 5da9fa6ecc..7f19b26a79 100644 --- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala @@ -17,12 +17,23 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => } var nextTime = 0L - - def start(): Long = { - nextTime = (math.floor(clock.currentTime / period) + 1).toLong * period - thread.start() + + def start(startTime: Long): Long = { + nextTime = startTime + thread.start() nextTime } + + def start(): Long = { + val startTime = math.ceil(clock.currentTime / period).toLong * period + start(startTime) + } + + def restart(originalStartTime: Long): Long = { + val gap = clock.currentTime - originalStartTime + val newStartTime = math.ceil(gap / period).toLong * period + originalStartTime + start(newStartTime) + } def stop() { thread.interrupt() -- cgit v1.2.3 From 52989c8a2c8c10d7f5610c033f6782e58fd3abc2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 19 Oct 2012 10:24:49 -0700 Subject: Update Python API for v0.6.0 compatibility. --- core/src/main/scala/spark/api/python/PythonRDD.scala | 18 +++++++++++------- core/src/main/scala/spark/broadcast/Broadcast.scala | 2 +- pyspark/pyspark/broadcast.py | 18 +++++++++--------- pyspark/pyspark/context.py | 2 +- pyspark/pyspark/java_gateway.py | 3 ++- pyspark/pyspark/serializers.py | 18 ++++++++++++++---- pyspark/pyspark/worker.py | 8 ++++---- 7 files changed, 42 insertions(+), 27 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 4d3bdb3963..528885fe5c 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -5,11 +5,15 @@ import java.io._ import scala.collection.Map import scala.collection.JavaConversions._ import scala.io.Source -import spark._ -import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} -import broadcast.Broadcast -import scala.collection -import java.nio.charset.Charset + +import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import spark.broadcast.Broadcast +import spark.SparkEnv +import spark.Split +import spark.RDD +import spark.OneToOneDependency +import spark.rdd.PipedRDD + trait PythonRDDBase { def compute[T](split: Split, envVars: Map[String, String], @@ -43,9 +47,9 @@ trait PythonRDDBase { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) val dOut = new DataOutputStream(proc.getOutputStream) - out.println(broadcastVars.length) + dOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { - out.print(broadcast.uuid.toString) + dOut.writeLong(broadcast.id) dOut.writeInt(broadcast.value.length) dOut.write(broadcast.value) dOut.flush() diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala index 6055bfd045..2ffe7f741d 100644 --- a/core/src/main/scala/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/spark/broadcast/Broadcast.scala @@ -5,7 +5,7 @@ import java.util.concurrent.atomic.AtomicLong import spark._ -abstract class Broadcast[T](id: Long) extends Serializable { +abstract class Broadcast[T](private[spark] val id: Long) extends Serializable { def value: T // We cannot have an abstract readObject here due to some weird issues with diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py index 1ea17d59af..4cff02b36d 100644 --- a/pyspark/pyspark/broadcast.py +++ b/pyspark/pyspark/broadcast.py @@ -6,7 +6,7 @@ [1, 2, 3, 4, 5] >>> from pyspark.broadcast import _broadcastRegistry ->>> _broadcastRegistry[b.uuid] = b +>>> _broadcastRegistry[b.bid] = b >>> from cPickle import dumps, loads >>> loads(dumps(b)).value [1, 2, 3, 4, 5] @@ -14,27 +14,27 @@ >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] """ -# Holds broadcasted data received from Java, keyed by UUID. +# Holds broadcasted data received from Java, keyed by its id. _broadcastRegistry = {} -def _from_uuid(uuid): +def _from_id(bid): from pyspark.broadcast import _broadcastRegistry - if uuid not in _broadcastRegistry: - raise Exception("Broadcast variable '%s' not loaded!" % uuid) - return _broadcastRegistry[uuid] + if bid not in _broadcastRegistry: + raise Exception("Broadcast variable '%s' not loaded!" % bid) + return _broadcastRegistry[bid] class Broadcast(object): - def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None): + def __init__(self, bid, value, java_broadcast=None, pickle_registry=None): self.value = value - self.uuid = uuid + self.bid = bid self._jbroadcast = java_broadcast self._pickle_registry = pickle_registry def __reduce__(self): self._pickle_registry.add(self) - return (_from_uuid, (self.uuid, )) + return (_from_id, (self.bid, )) def _test(): diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 04932c93f2..3f4db26644 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -66,5 +66,5 @@ class SparkContext(object): def broadcast(self, value): jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) - return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast, + return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py index bcb405ba72..3726bcbf17 100644 --- a/pyspark/pyspark/java_gateway.py +++ b/pyspark/pyspark/java_gateway.py @@ -7,7 +7,8 @@ SPARK_HOME = os.environ["SPARK_HOME"] assembly_jar = glob.glob(os.path.join(SPARK_HOME, "core/target") + \ - "/spark-core-assembly-*-SNAPSHOT.jar")[0] + "/spark-core-assembly-*.jar")[0] + # TODO: what if multiple assembly jars are found? def launch_gateway(): diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index faa1e683c7..21ef8b106c 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -9,16 +9,26 @@ def dump_pickle(obj): load_pickle = cPickle.loads +def read_long(stream): + length = stream.read(8) + if length == "": + raise EOFError + return struct.unpack("!q", length)[0] + + +def read_int(stream): + length = stream.read(4) + if length == "": + raise EOFError + return struct.unpack("!i", length)[0] + def write_with_length(obj, stream): stream.write(struct.pack("!i", len(obj))) stream.write(obj) def read_with_length(stream): - length = stream.read(4) - if length == "": - raise EOFError - length = struct.unpack("!i", length)[0] + length = read_int(stream) obj = stream.read(length) if obj == "": raise EOFError diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index a9ed71892f..62824a1c9b 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -8,7 +8,7 @@ from base64 import standard_b64decode from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.serializers import write_with_length, read_with_length, \ - dump_pickle, load_pickle + read_long, read_int, dump_pickle, load_pickle # Redirect stdout to stderr so that users must return values from functions. @@ -29,11 +29,11 @@ def read_input(): def main(): - num_broadcast_variables = int(sys.stdin.readline().strip()) + num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): - uuid = sys.stdin.read(36) + bid = read_long(sys.stdin) value = read_with_length(sys.stdin) - _broadcastRegistry[uuid] = Broadcast(uuid, load_pickle(value)) + _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) func = load_obj() bypassSerializer = load_obj() if bypassSerializer: -- cgit v1.2.3 From c23bf1aff4b9a1faf9d32c7b64acad2213f9515c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 20 Oct 2012 00:16:41 +0000 Subject: Add PySpark README and run scripts. --- core/src/main/scala/spark/SparkContext.scala | 2 +- pyspark/README | 58 ++++++++++++++++++++++++++++ pyspark/pyspark-shell | 3 ++ pyspark/pyspark/context.py | 5 +-- pyspark/pyspark/examples/wordcount.py | 17 ++++++++ pyspark/pyspark/shell.py | 21 ++++++++++ pyspark/run-pyspark | 23 +++++++++++ 7 files changed, 125 insertions(+), 4 deletions(-) create mode 100644 pyspark/README create mode 100755 pyspark/pyspark-shell create mode 100644 pyspark/pyspark/examples/wordcount.py create mode 100644 pyspark/pyspark/shell.py create mode 100755 pyspark/run-pyspark (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index becf737597..acb38ae33d 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -113,7 +113,7 @@ class SparkContext( // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", - "SPARK_TESTING")) { + "SPARK_TESTING", "PYTHONPATH")) { val value = System.getenv(key) if (value != null) { executorEnvs(key) = value diff --git a/pyspark/README b/pyspark/README new file mode 100644 index 0000000000..63a1def141 --- /dev/null +++ b/pyspark/README @@ -0,0 +1,58 @@ +# PySpark + +PySpark is a Python API for Spark. + +PySpark jobs are writen in Python and executed using a standard Python +interpreter; this supports modules that use Python C extensions. The +API is based on the Spark Scala API and uses regular Python functions +and lambdas to support user-defined functions. PySpark supports +interactive use through a standard Python interpreter; it can +automatically serialize closures and ship them to worker processes. + +PySpark is built on top of the Spark Java API. Data is uniformly +represented as serialized Python objects and stored in Spark Java +processes, which communicate with PySpark worker processes over pipes. + +## Features + +PySpark supports most of the Spark API, including broadcast variables. +RDDs are dynamically typed and can hold any Python object. + +PySpark does not support: + +- Special functions on RDDs of doubles +- Accumulators + +## Examples and Documentation + +The PySpark source contains docstrings and doctests that document its +API. The public classes are in `context.py` and `rdd.py`. + +The `pyspark/pyspark/examples` directory contains a few complete +examples. + +## Installing PySpark + +PySpark requires a development version of Py4J, a Python library for +interacting with Java processes. It can be installed from +https://github.com/bartdag/py4j; make sure to install a version that +contains at least the commits through 3dbf380d3d. + +PySpark uses the `PYTHONPATH` environment variable to search for Python +classes; Py4J should be on this path, along with any libraries used by +PySpark programs. `PYTHONPATH` will be automatically shipped to worker +machines, but the files that it points to must be present on each +machine. + +PySpark requires the Spark assembly JAR, which can be created by running +`sbt/sbt assembly` in the Spark directory. + +Additionally, `SPARK_HOME` should be set to the location of the Spark +package. + +## Running PySpark + +The easiest way to run PySpark is to use the `run-pyspark` and +`pyspark-shell` scripts, which are included in the `pyspark` directory. +These scripts automatically load the `spark-conf.sh` file, set +`SPARK_HOME`, and add the `pyspark` package to the `PYTHONPATH`. diff --git a/pyspark/pyspark-shell b/pyspark/pyspark-shell new file mode 100755 index 0000000000..4ed3e6010c --- /dev/null +++ b/pyspark/pyspark-shell @@ -0,0 +1,3 @@ +#!/bin/sh +FWDIR="`dirname $0`" +exec $FWDIR/run-pyspark $FWDIR/pyspark/shell.py "$@" diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 3f4db26644..50d57e5317 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -18,14 +18,13 @@ class SparkContext(object): asPickle = jvm.spark.api.python.PythonRDD.asPickle arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle - def __init__(self, master, name, defaultParallelism=None, - pythonExec='python'): + def __init__(self, master, name, defaultParallelism=None): self.master = master self.name = name self._jsc = self.jvm.JavaSparkContext(master, name) self.defaultParallelism = \ defaultParallelism or self._jsc.sc().defaultParallelism() - self.pythonExec = pythonExec + self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python') # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have # been pickled, so it can determine which Java broadcast objects to diff --git a/pyspark/pyspark/examples/wordcount.py b/pyspark/pyspark/examples/wordcount.py new file mode 100644 index 0000000000..8365c070e8 --- /dev/null +++ b/pyspark/pyspark/examples/wordcount.py @@ -0,0 +1,17 @@ +import sys +from operator import add +from pyspark.context import SparkContext + +if __name__ == "__main__": + if len(sys.argv) < 3: + print >> sys.stderr, \ + "Usage: PythonWordCount " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonWordCount") + lines = sc.textFile(sys.argv[2], 1) + counts = lines.flatMap(lambda x: x.split(' ')) \ + .map(lambda x: (x, 1)) \ + .reduceByKey(add) + output = counts.collect() + for (word, count) in output: + print "%s : %i" % (word, count) diff --git a/pyspark/pyspark/shell.py b/pyspark/pyspark/shell.py new file mode 100644 index 0000000000..7ef30894cb --- /dev/null +++ b/pyspark/pyspark/shell.py @@ -0,0 +1,21 @@ +""" +An interactive shell. +""" +import code +import sys + +from pyspark.context import SparkContext + + +def main(master='local'): + sc = SparkContext(master, 'PySparkShell') + print "Spark context available as sc." + code.interact(local={'sc': sc}) + + +if __name__ == '__main__': + if len(sys.argv) > 1: + master = sys.argv[1] + else: + master = 'local' + main(master) diff --git a/pyspark/run-pyspark b/pyspark/run-pyspark new file mode 100755 index 0000000000..9c5e027962 --- /dev/null +++ b/pyspark/run-pyspark @@ -0,0 +1,23 @@ +#!/bin/bash + +# Figure out where the Scala framework is installed +FWDIR="$(cd `dirname $0`; cd ../; pwd)" + +# Export this as SPARK_HOME +export SPARK_HOME="$FWDIR" + +# Load environment variables from conf/spark-env.sh, if it exists +if [ -e $FWDIR/conf/spark-env.sh ] ; then + . $FWDIR/conf/spark-env.sh +fi + +# Figure out which Python executable to use +if [ -z "$PYSPARK_PYTHON" ] ; then + PYSPARK_PYTHON="python" +fi +export PYSPARK_PYTHON + +# Add the PySpark classes to the Python path: +export PYTHONPATH=$SPARK_HOME/pyspark/:$PYTHONPATH + +exec "$PYSPARK_PYTHON" "$@" -- cgit v1.2.3 From d4f2e5b0ef38db9d42bb0d5fbbbe6103ce047efe Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 22 Oct 2012 10:28:59 -0700 Subject: Remove PYTHONPATH from SparkContext's executorEnvs. It makes more sense to pass it in the dictionary of environment variables that is used to construct PythonRDD. --- core/src/main/scala/spark/SparkContext.scala | 2 +- core/src/main/scala/spark/api/python/PythonRDD.scala | 15 +++++++-------- pyspark/pyspark/rdd.py | 8 ++++++-- 3 files changed, 14 insertions(+), 11 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index acb38ae33d..becf737597 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -113,7 +113,7 @@ class SparkContext( // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", - "SPARK_TESTING", "PYTHONPATH")) { + "SPARK_TESTING")) { val value = System.getenv(key) if (value != null) { executorEnvs(key) = value diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 528885fe5c..a593e53efd 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -131,18 +131,17 @@ trait PythonRDDBase { } class PythonRDD[T: ClassManifest]( - parent: RDD[T], command: Seq[String], envVars: Map[String, String], + parent: RDD[T], command: Seq[String], envVars: java.util.Map[String, String], preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) extends RDD[Array[Byte]](parent.context) with PythonRDDBase { - def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, - pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = - this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars) - // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars) + def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], + preservePartitoning: Boolean, pythonExec: String, + broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec, + broadcastVars) override def splits = parent.splits @@ -151,7 +150,7 @@ class PythonRDD[T: ClassManifest]( override val partitioner = if (preservePartitoning) parent.partitioner else None override def compute(split: Split): Iterator[Array[Byte]] = - compute(split, envVars, command, parent, pythonExec, broadcastVars) + compute(split, envVars.toMap, command, parent, pythonExec, broadcastVars) val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index e2137fe06c..e4878c08ba 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,6 +1,7 @@ from base64 import standard_b64encode as b64enc from collections import defaultdict from itertools import chain, ifilter, imap +import os import shlex from subprocess import Popen, PIPE from threading import Thread @@ -10,7 +11,7 @@ from pyspark.serializers import dump_pickle, load_pickle from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup -from py4j.java_collections import ListConverter +from py4j.java_collections import ListConverter, MapConverter class RDD(object): @@ -447,8 +448,11 @@ class PipelinedRDD(RDD): self.ctx.gateway._gateway_client) self.ctx._pickled_broadcast_vars.clear() class_manifest = self._prev_jrdd.classManifest() + env = MapConverter().convert( + {'PYTHONPATH' : os.environ.get("PYTHONPATH", "")}, + self.ctx.gateway._gateway_client) python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), - pipe_command, self.preservesPartitioning, self.ctx.pythonExec, + pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, broadcast_vars, class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val -- cgit v1.2.3 From 2ccf3b665280bf5b0919e3801d028126cb070dbd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 28 Oct 2012 22:30:28 -0700 Subject: Fix PySpark hash partitioning bug. A Java array's hashCode is based on its object identify, not its elements, so this was causing serialized keys to be hashed incorrectly. This commit adds a PySpark-specific workaround and adds more tests. --- .../scala/spark/api/python/PythonPartitioner.scala | 41 ++++++++++++++++++++++ .../main/scala/spark/api/python/PythonRDD.scala | 10 +++--- pyspark/pyspark/rdd.py | 12 +++++-- 3 files changed, 54 insertions(+), 9 deletions(-) create mode 100644 core/src/main/scala/spark/api/python/PythonPartitioner.scala (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala new file mode 100644 index 0000000000..ef9f808fb2 --- /dev/null +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -0,0 +1,41 @@ +package spark.api.python + +import spark.Partitioner + +import java.util.Arrays + +/** + * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API. + */ +class PythonPartitioner(override val numPartitions: Int) extends Partitioner { + + override def getPartition(key: Any): Int = { + if (key == null) { + return 0 + } + else { + val hashCode = { + if (key.isInstanceOf[Array[Byte]]) { + System.err.println("Dumping a byte array!" + Arrays.hashCode(key.asInstanceOf[Array[Byte]]) + ) + Arrays.hashCode(key.asInstanceOf[Array[Byte]]) + } + else + key.hashCode() + } + val mod = hashCode % numPartitions + if (mod < 0) { + mod + numPartitions + } else { + mod // Guard against negative hash codes + } + } + } + + override def equals(other: Any): Boolean = other match { + case h: PythonPartitioner => + h.numPartitions == numPartitions + case _ => + false + } +} diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index a593e53efd..50094d6b0f 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -179,14 +179,12 @@ object PythonRDD { val dOut = new DataOutputStream(baos); if (elem.isInstanceOf[Array[Byte]]) { elem.asInstanceOf[Array[Byte]] - } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) { - val t = elem.asInstanceOf[scala.Tuple2[_, _]] - val t1 = t._1.asInstanceOf[Array[Byte]] - val t2 = t._2.asInstanceOf[Array[Byte]] + } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { + val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] dOut.writeByte(Pickle.PROTO) dOut.writeByte(Pickle.TWO) - dOut.write(PythonRDD.stripPickle(t1)) - dOut.write(PythonRDD.stripPickle(t2)) + dOut.write(PythonRDD.stripPickle(t._1)) + dOut.write(PythonRDD.stripPickle(t._2)) dOut.writeByte(Pickle.TUPLE2) dOut.writeByte(Pickle.STOP) baos.toByteArray() diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index e4878c08ba..85a24c6854 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -310,6 +310,12 @@ class RDD(object): return python_right_outer_join(self, other, numSplits) def partitionBy(self, numSplits, hashFunc=hash): + """ + >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x)) + >>> sets = pairs.partitionBy(2).glom().collect() + >>> set(sets[0]).intersection(set(sets[1])) + set([]) + """ if numSplits is None: numSplits = self.ctx.defaultParallelism def add_shuffle_key(iterator): @@ -319,7 +325,7 @@ class RDD(object): keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) + partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) jrdd = pairRDD.partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) return RDD(jrdd, self.ctx) @@ -391,7 +397,7 @@ class RDD(object): """ >>> x = sc.parallelize([("a", 1), ("b", 4)]) >>> y = sc.parallelize([("a", 2)]) - >>> x.cogroup(y).collect() + >>> sorted(x.cogroup(y).collect()) [('a', ([1], [2])), ('b', ([4], []))] """ return python_cogroup(self, other, numSplits) @@ -462,7 +468,7 @@ def _test(): import doctest from pyspark.context import SparkContext globs = globals().copy() - globs['sc'] = SparkContext('local', 'PythonTest') + globs['sc'] = SparkContext('local[4]', 'PythonTest') doctest.testmod(globs=globs) globs['sc'].stop() -- cgit v1.2.3 From ac12abc17ff90ec99192f3c3de4d1d390969e635 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 29 Oct 2012 11:55:27 -0700 Subject: Modified RDD API to make dependencies a var (therefore can be changed to checkpointed hadoop rdd) and othere references to parent RDDs either through dependencies or through a weak reference (to allow finalizing when dependencies do not refer to it any more). --- core/src/main/scala/spark/PairRDDFunctions.scala | 24 ++++++++++----------- core/src/main/scala/spark/ParallelCollection.scala | 8 +++---- core/src/main/scala/spark/RDD.scala | 25 ++++++++++++++++------ core/src/main/scala/spark/SparkContext.scala | 4 ++++ core/src/main/scala/spark/rdd/CartesianRDD.scala | 21 ++++++++++++------ core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 13 +++++++---- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 19 ++++++++++------ core/src/main/scala/spark/rdd/FilteredRDD.scala | 12 +++++++---- core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 10 ++++----- core/src/main/scala/spark/rdd/GlommedRDD.scala | 9 ++++---- core/src/main/scala/spark/rdd/HadoopRDD.scala | 4 +--- .../main/scala/spark/rdd/MapPartitionsRDD.scala | 12 +++++------ .../spark/rdd/MapPartitionsWithSplitRDD.scala | 10 ++++----- core/src/main/scala/spark/rdd/MappedRDD.scala | 12 +++++------ core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 9 ++++---- core/src/main/scala/spark/rdd/PipedRDD.scala | 16 +++++++------- core/src/main/scala/spark/rdd/SampledRDD.scala | 15 ++++++------- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 13 ++++++----- core/src/main/scala/spark/rdd/UnionRDD.scala | 20 ++++++++++------- 19 files changed, 149 insertions(+), 107 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index e5bb639cfd..f52af08125 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -23,6 +23,7 @@ import spark.partial.BoundedDouble import spark.partial.PartialResult import spark.rdd._ import spark.SparkContext._ +import java.lang.ref.WeakReference /** * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. @@ -624,23 +625,22 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( } private[spark] -class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override val partitioner = prev.partitioner - override def compute(split: Split) = prev.iterator(split).map{case (k, v) => (k, f(v))} +class MappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V => U) + extends RDD[(K, U)](prev.get) { + + override def splits = firstParent[(K, V)].splits + override val partitioner = firstParent[(K, V)].partitioner + override def compute(split: Split) = firstParent[(K, V)].iterator(split).map{case (k, v) => (k, f(v))} } private[spark] -class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U]) - extends RDD[(K, U)](prev.context) { - - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override val partitioner = prev.partitioner +class FlatMappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U]) + extends RDD[(K, U)](prev.get) { + override def splits = firstParent[(K, V)].splits + override val partitioner = firstParent[(K, V)].partitioner override def compute(split: Split) = { - prev.iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) } + firstParent[(K, V)].iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) } } } diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9b57ae3b4f..ad06ee9736 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -22,13 +22,13 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( } private[spark] class ParallelCollection[T: ClassManifest]( - sc: SparkContext, + @transient sc_ : SparkContext, @transient data: Seq[T], numSlices: Int) - extends RDD[T](sc) { + extends RDD[T](sc_, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split - // instead. + // instead. UPDATE: With the new changes to enable checkpointing, this an be done. @transient val splits_ = { @@ -41,8 +41,6 @@ private[spark] class ParallelCollection[T: ClassManifest]( override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator override def preferredLocations(s: Split): Seq[String] = Nil - - override val dependencies: List[Dependency[_]] = Nil } private object ParallelCollection { diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 338dff4061..c9f3763f73 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -72,7 +72,14 @@ import SparkContext._ * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details * on RDD internals. */ -abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serializable { +abstract class RDD[T: ClassManifest]( + @transient var sc: SparkContext, + @transient var dependencies_ : List[Dependency[_]] = Nil + ) extends Serializable { + + + def this(@transient oneParent: RDD[_]) = + this(oneParent.context , List(new OneToOneDependency(oneParent))) // Methods that must be implemented by subclasses: @@ -83,10 +90,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def compute(split: Split): Iterator[T] /** How this RDD depends on any parent RDDs. */ - @transient val dependencies: List[Dependency[_]] + def dependencies: List[Dependency[_]] = dependencies_ + //var dependencies: List[Dependency[_]] = dependencies_ - // Methods available on all RDDs: - /** Record user function generating this RDD. */ private[spark] val origin = Utils.getSparkCallSite @@ -106,8 +112,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE - - /** + + private[spark] def firstParent[U: ClassManifest] = dependencies.head.rdd.asInstanceOf[RDD[U]] + private[spark] def parent[U: ClassManifest](id: Int) = dependencies(id).rdd.asInstanceOf[RDD[U]] + + // Methods available on all RDDs: + + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. */ @@ -129,7 +140,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - + private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = { if (!level.useDisk && level.replication < 2) { throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 0d37075ef3..6b957a6356 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -3,6 +3,7 @@ package spark import java.io._ import java.util.concurrent.atomic.AtomicInteger import java.net.{URI, URLClassLoader} +import java.lang.ref.WeakReference import scala.collection.Map import scala.collection.generic.Growable @@ -695,6 +696,9 @@ object SparkContext { /** Find the JAR that contains the class of a particular object */ def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass) + + implicit def rddToWeakRefRDD[T: ClassManifest](rdd: RDD[T]) = new WeakReference(rdd) + } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 7c354b6b2e..c97b835630 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -4,6 +4,7 @@ import spark.NarrowDependency import spark.RDD import spark.SparkContext import spark.Split +import java.lang.ref.WeakReference private[spark] class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable { @@ -13,13 +14,17 @@ class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with private[spark] class CartesianRDD[T: ClassManifest, U:ClassManifest]( sc: SparkContext, - rdd1: RDD[T], - rdd2: RDD[U]) + rdd1_ : WeakReference[RDD[T]], + rdd2_ : WeakReference[RDD[U]]) extends RDD[Pair[T, U]](sc) with Serializable { - + + def rdd1 = rdd1_.get + def rdd2 = rdd2_.get + val numSplitsInRdd2 = rdd2.splits.size - + + // TODO: make this null when finishing checkpoint @transient val splits_ = { // create the cross product split @@ -31,6 +36,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( array } + // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override def preferredLocations(split: Split) = { @@ -42,8 +48,9 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( val currSplit = split.asInstanceOf[CartesianSplit] for (x <- rdd1.iterator(currSplit.s1); y <- rdd2.iterator(currSplit.s2)) yield (x, y) } - - override val dependencies = List( + + // TODO: make this null when finishing checkpoint + var deps = List( new NarrowDependency(rdd1) { def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2) }, @@ -51,4 +58,6 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2) } ) + + override def dependencies = deps } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 50bec9e63b..af54ac2fa0 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -31,12 +31,13 @@ private[spark] class CoGroupAggregator with Serializable class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) - extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging { + extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging { val aggr = new CoGroupAggregator - + + // TODO: make this null when finishing checkpoint @transient - override val dependencies = { + var deps = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) @@ -50,7 +51,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) } deps.toList } - + + override def dependencies = deps + + // TODO: make this null when finishing checkpoint @transient val splits_ : Array[Split] = { val firstRdd = rdds.head @@ -68,6 +72,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) array } + // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override val partitioner = Some(part) diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 0967f4f5df..573acf8893 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -14,11 +14,14 @@ private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) exten * This transformation is useful when an RDD with many partitions gets filtered into a smaller one, * or to avoid having a large number of small tasks when processing a directory with many files. */ -class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int) - extends RDD[T](prev.context) { +class CoalescedRDD[T: ClassManifest]( + @transient prev: RDD[T], // TODO: Make this a weak reference + maxPartitions: Int) + extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs + // TODO: make this null when finishing checkpoint @transient val splits_ : Array[Split] = { - val prevSplits = prev.splits + val prevSplits = firstParent[T].splits if (prevSplits.length < maxPartitions) { prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) } } else { @@ -30,18 +33,22 @@ class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int) } } + // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override def compute(split: Split): Iterator[T] = { split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { - parentSplit => prev.iterator(parentSplit) + parentSplit => firstParent[T].iterator(parentSplit) } } - val dependencies = List( - new NarrowDependency(prev) { + // TODO: make this null when finishing checkpoint + var deps = List( + new NarrowDependency(firstParent) { def getParents(id: Int): Seq[Int] = splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index) } ) + + override def dependencies = deps } diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index dfe9dc73f3..cc2a3acd3a 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -3,10 +3,14 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] -class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).filter(f) +class FilteredRDD[T: ClassManifest]( + @transient prev: WeakReference[RDD[T]], + f: T => Boolean) + extends RDD[T](prev.get) { + + override def splits = firstParent[T].splits + override def compute(split: Split) = firstParent[T].iterator(split).filter(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index 3534dc8057..34bd784c13 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -3,14 +3,14 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + @transient prev: WeakReference[RDD[T]], f: T => TraversableOnce[U]) - extends RDD[U](prev.context) { + extends RDD[U](prev.get) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).flatMap(f) + override def splits = firstParent[T].splits + override def compute(split: Split) = firstParent[T].iterator(split).flatMap(f) } diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index e30564f2da..9321e89dcd 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -3,10 +3,11 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] -class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator +class GlommedRDD[T: ClassManifest](@transient prev: WeakReference[RDD[T]]) + extends RDD[Array[T]](prev.get) { + override def splits = firstParent[T].splits + override def compute(split: Split) = Array(firstParent[T].iterator(split).toArray).iterator } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index bf29a1f075..a12531ea89 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -46,7 +46,7 @@ class HadoopRDD[K, V]( keyClass: Class[K], valueClass: Class[V], minSplits: Int) - extends RDD[(K, V)](sc) { + extends RDD[(K, V)](sc, Nil) { // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it val confBroadcast = sc.broadcast(new SerializableWritable(conf)) @@ -115,6 +115,4 @@ class HadoopRDD[K, V]( val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } - - override val dependencies: List[Dependency[_]] = Nil } diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index a904ef62c3..bad872c430 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -3,17 +3,17 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + @transient prev: WeakReference[RDD[T]], f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false) - extends RDD[U](prev.context) { + extends RDD[U](prev.get) { - override val partitioner = if (preservesPartitioning) prev.partitioner else None + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = f(prev.iterator(split)) + override def splits = firstParent[T].splits + override def compute(split: Split) = f(firstParent[T].iterator(split)) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index adc541694e..d7b238b05d 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -3,6 +3,7 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference /** * A variant of the MapPartitionsRDD that passes the split index into the @@ -11,11 +12,10 @@ import spark.Split */ private[spark] class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + @transient prev: WeakReference[RDD[T]], f: (Int, Iterator[T]) => Iterator[U]) - extends RDD[U](prev.context) { + extends RDD[U](prev.get) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = f(split.index, prev.iterator(split)) + override def splits = firstParent[T].splits + override def compute(split: Split) = f(split.index, firstParent[T].iterator(split)) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 59bedad8ef..126c6f332b 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -3,14 +3,14 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] class MappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + @transient prev: WeakReference[RDD[T]], f: T => U) - extends RDD[U](prev.context) { - - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).map(f) + extends RDD[U](prev.get) { + + override def splits = firstParent[T].splits + override def compute(split: Split) = firstParent[T].iterator(split).map(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index 7a1a0fb87d..c12df5839e 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -23,11 +23,12 @@ class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit } class NewHadoopRDD[K, V]( - sc: SparkContext, + sc : SparkContext, inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], valueClass: Class[V], + keyClass: Class[K], + valueClass: Class[V], @transient conf: Configuration) - extends RDD[(K, V)](sc) + extends RDD[(K, V)](sc, Nil) with HadoopMapReduceUtil { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it @@ -92,6 +93,4 @@ class NewHadoopRDD[K, V]( val theSplit = split.asInstanceOf[NewHadoopSplit] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") } - - override val dependencies: List[Dependency[_]] = Nil } diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index 98ea0c92d6..d54579d6d1 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -19,18 +19,18 @@ import spark.Split * (printing them one per line) and returns the output as a collection of strings. */ class PipedRDD[T: ClassManifest]( - parent: RDD[T], command: Seq[String], envVars: Map[String, String]) - extends RDD[String](parent.context) { + @transient prev: RDD[T], + command: Seq[String], + envVars: Map[String, String]) + extends RDD[String](prev) { - def this(parent: RDD[T], command: Seq[String]) = this(parent, command, Map()) + def this(@transient prev: RDD[T], command: Seq[String]) = this(prev, command, Map()) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String) = this(parent, PipedRDD.tokenize(command)) + def this(@transient prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) - override def splits = parent.splits - - override val dependencies = List(new OneToOneDependency(parent)) + override def splits = firstParent[T].splits override def compute(split: Split): Iterator[String] = { val pb = new ProcessBuilder(command) @@ -55,7 +55,7 @@ class PipedRDD[T: ClassManifest]( override def run() { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) - for (elem <- parent.iterator(split)) { + for (elem <- firstParent[T].iterator(split)) { out.println(elem) } out.close() diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index 87a5268f27..00b521b130 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -7,6 +7,7 @@ import cern.jet.random.engine.DRand import spark.RDD import spark.OneToOneDependency import spark.Split +import java.lang.ref.WeakReference private[spark] class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable { @@ -14,24 +15,22 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali } class SampledRDD[T: ClassManifest]( - prev: RDD[T], + @transient prev: WeakReference[RDD[T]], withReplacement: Boolean, frac: Double, seed: Int) - extends RDD[T](prev.context) { + extends RDD[T](prev.get) { @transient val splits_ = { val rg = new Random(seed) - prev.splits.map(x => new SampledRDDSplit(x, rg.nextInt)) + firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } override def splits = splits_.asInstanceOf[Array[Split]] - override val dependencies = List(new OneToOneDependency(prev)) - override def preferredLocations(split: Split) = - prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) + firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) override def compute(splitIn: Split) = { val split = splitIn.asInstanceOf[SampledRDDSplit] @@ -39,7 +38,7 @@ class SampledRDD[T: ClassManifest]( // For large datasets, the expected number of occurrences of each element in a sample with // replacement is Poisson(frac). We use that to get a count for each element. val poisson = new Poisson(frac, new DRand(split.seed)) - prev.iterator(split.prev).flatMap { element => + firstParent[T].iterator(split.prev).flatMap { element => val count = poisson.nextInt() if (count == 0) { Iterator.empty // Avoid object allocation when we return 0 items, which is quite often @@ -49,7 +48,7 @@ class SampledRDD[T: ClassManifest]( } } else { // Sampling without replacement val rand = new Random(split.seed) - prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac)) + firstParent[T].iterator(split.prev).filter(x => (rand.nextDouble <= frac)) } } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 145e419c53..62867dab4f 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -5,6 +5,7 @@ import spark.RDD import spark.ShuffleDependency import spark.SparkEnv import spark.Split +import java.lang.ref.WeakReference private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { override val index = idx @@ -19,8 +20,9 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { * @tparam V the value class. */ class ShuffledRDD[K, V]( - @transient parent: RDD[(K, V)], - part: Partitioner) extends RDD[(K, V)](parent.context) { + @transient prev: WeakReference[RDD[(K, V)]], + part: Partitioner) + extends RDD[(K, V)](prev.get.context, List(new ShuffleDependency(prev.get, part))) { override val partitioner = Some(part) @@ -31,10 +33,11 @@ class ShuffledRDD[K, V]( override def preferredLocations(split: Split) = Nil - val dep = new ShuffleDependency(parent, part) - override val dependencies = List(dep) + //val dep = new ShuffleDependency(parent, part) + //override val dependencies = List(dep) override def compute(split: Split): Iterator[(K, V)] = { - SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index) + val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId + SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index f0b9225f7c..0a61a2d1f5 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -7,6 +7,7 @@ import spark.RangeDependency import spark.RDD import spark.SparkContext import spark.Split +import java.lang.ref.WeakReference private[spark] class UnionSplit[T: ClassManifest]( idx: Int, @@ -22,10 +23,10 @@ private[spark] class UnionSplit[T: ClassManifest]( class UnionRDD[T: ClassManifest]( sc: SparkContext, - @transient rdds: Seq[RDD[T]]) - extends RDD[T](sc) - with Serializable { - + @transient rdds: Seq[RDD[T]]) // TODO: Make this a weak reference + extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs + + // TODO: make this null when finishing checkpoint @transient val splits_ : Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) @@ -37,19 +38,22 @@ class UnionRDD[T: ClassManifest]( array } + // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ - @transient - override val dependencies = { + // TODO: make this null when finishing checkpoint + @transient var deps = { val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) + deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) pos += rdd.splits.size } deps.toList } - + + override def dependencies = deps + override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() override def preferredLocations(s: Split): Seq[String] = -- cgit v1.2.3 From 531ac136bf4ed333cb906ac229d986605a8207a6 Mon Sep 17 00:00:00 2001 From: Denny Date: Mon, 29 Oct 2012 14:53:47 -0700 Subject: BlockManager UI. --- core/src/main/scala/spark/RDD.scala | 8 ++ core/src/main/scala/spark/SparkContext.scala | 10 ++ .../scala/spark/storage/BlockManagerMaster.scala | 33 ++++++- .../main/scala/spark/storage/BlockManagerUI.scala | 102 +++++++++++++++++++++ core/src/main/scala/spark/util/AkkaUtils.scala | 5 +- core/src/main/twirl/spark/common/layout.scala.html | 35 +++++++ .../twirl/spark/deploy/common/layout.scala.html | 35 ------- .../twirl/spark/deploy/master/index.scala.html | 2 +- .../spark/deploy/master/job_details.scala.html | 2 +- .../twirl/spark/deploy/worker/index.scala.html | 2 +- core/src/main/twirl/spark/storage/index.scala.html | 28 ++++++ core/src/main/twirl/spark/storage/rdd.scala.html | 65 +++++++++++++ .../main/twirl/spark/storage/rdd_row.scala.html | 18 ++++ .../main/twirl/spark/storage/rdd_table.scala.html | 18 ++++ 14 files changed, 318 insertions(+), 45 deletions(-) create mode 100644 core/src/main/scala/spark/storage/BlockManagerUI.scala create mode 100644 core/src/main/twirl/spark/common/layout.scala.html delete mode 100644 core/src/main/twirl/spark/deploy/common/layout.scala.html create mode 100644 core/src/main/twirl/spark/storage/index.scala.html create mode 100644 core/src/main/twirl/spark/storage/rdd.scala.html create mode 100644 core/src/main/twirl/spark/storage/rdd_row.scala.html create mode 100644 core/src/main/twirl/spark/storage/rdd_table.scala.html (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 338dff4061..dc757dc6aa 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -107,6 +107,12 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE + /* Assign a name to this RDD */ + def name(name: String) = { + sc.rddNames(this.id) = name + this + } + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. @@ -118,6 +124,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial "Cannot change storage level of an RDD after it was already assigned a level") } storageLevel = newLevel + // Register the RDD with the SparkContext + sc.persistentRdds(id) = this this } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d26cccbfe1..71c9dcd017 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -1,6 +1,7 @@ package spark import java.io._ +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicInteger import java.net.{URI, URLClassLoader} @@ -102,10 +103,19 @@ class SparkContext( isLocal) SparkEnv.set(env) + // Start the BlockManager UI + spark.storage.BlockManagerUI.start(SparkEnv.get.actorSystem, + SparkEnv.get.blockManager.master.masterActor, this) + // Used to store a URL for each static file/jar together with the file's local timestamp private[spark] val addedFiles = HashMap[String, Long]() private[spark] val addedJars = HashMap[String, Long]() + // Keeps track of all persisted RDDs + private[spark] val persistentRdds = new ConcurrentHashMap[Int, RDD[_]]() + // A HashMap for friendly RDD Names + private[spark] val rddNames = new ConcurrentHashMap[Int, String]() + // Add each JAR given through the constructor jars.foreach { addJar(_) } diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index ace27e758c..d12a16869a 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -3,7 +3,8 @@ package spark.storage import java.io._ import java.util.{HashMap => JHashMap} -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConverters._ +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.util.Random import akka.actor._ @@ -90,6 +91,15 @@ case object StopBlockManagerMaster extends ToBlockManagerMaster private[spark] case object GetMemoryStatus extends ToBlockManagerMaster +private[spark] +case class GetStorageStatus extends ToBlockManagerMaster + +private[spark] +case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) + +private[spark] +case class StorageStatus(maxMem: Long, remainingMem: Long, blocks: Map[String, BlockStatus]) + private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { @@ -99,7 +109,8 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor val maxMem: Long) { private var _lastSeenMs = timeMs private var _remainingMem = maxMem - private val _blocks = new JHashMap[String, StorageLevel] + + private val _blocks = new JHashMap[String, BlockStatus] logInfo("Registering block manager %s:%d with %s RAM".format( blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) @@ -115,7 +126,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor if (_blocks.containsKey(blockId)) { // The block exists on the slave already. - val originalLevel: StorageLevel = _blocks.get(blockId) + val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel if (originalLevel.useMemory) { _remainingMem += memSize @@ -124,7 +135,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor if (storageLevel.isValid) { // isValid means it is either stored in-memory or on-disk. - _blocks.put(blockId, storageLevel) + _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) if (storageLevel.useMemory) { _remainingMem -= memSize logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( @@ -137,7 +148,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor } } else if (_blocks.containsKey(blockId)) { // If isValid is not true, drop the block. - val originalLevel: StorageLevel = _blocks.get(blockId) + val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel _blocks.remove(blockId) if (originalLevel.useMemory) { _remainingMem += memSize @@ -152,6 +163,8 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor } } + def blocks: JHashMap[String, BlockStatus] = _blocks + def remainingMem: Long = _remainingMem def lastSeenMs: Long = _lastSeenMs @@ -198,6 +211,9 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor case GetMemoryStatus => getMemoryStatus + case GetStorageStatus => + getStorageStatus + case RemoveHost(host) => removeHost(host) sender ! true @@ -219,6 +235,13 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor sender ! res } + private def getStorageStatus() { + val res = blockManagerInfo.map { case(blockManagerId, info) => + StorageStatus(info.maxMem, info.remainingMem, info.blocks.asScala) + } + sender ! res + } + private def register(blockManagerId: BlockManagerId, maxMemSize: Long) { val startTimeMs = System.currentTimeMillis() val tmp = " " + blockManagerId + " " diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala new file mode 100644 index 0000000000..c168f60c35 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -0,0 +1,102 @@ +package spark.storage + +import akka.actor.{ActorRef, ActorSystem} +import akka.dispatch.Await +import akka.pattern.ask +import akka.util.Timeout +import akka.util.duration._ +import cc.spray.Directives +import cc.spray.directives._ +import cc.spray.typeconversion.TwirlSupport._ +import scala.collection.mutable.ArrayBuffer +import spark.{Logging, SparkContext, SparkEnv} +import spark.util.AkkaUtils + +private[spark] +object BlockManagerUI extends Logging { + + /* Starts the Web interface for the BlockManager */ + def start(actorSystem : ActorSystem, masterActor: ActorRef, sc: SparkContext) { + val webUIDirectives = new BlockManagerUIDirectives(actorSystem, masterActor, sc) + try { + logInfo("Starting BlockManager WebUI.") + val port = Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt + AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, webUIDirectives.handler, "BlockManagerHTTPServer") + } catch { + case e: Exception => + logError("Failed to create BlockManager WebUI", e) + System.exit(1) + } + } + +} + +private[spark] +case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, numPartitions: Int, memSize: Long, diskSize: Long) + +private[spark] +class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, sc: SparkContext) extends Directives { + + val STATIC_RESOURCE_DIR = "spark/deploy/static" + implicit val timeout = Timeout(1 seconds) + + val handler = { + + get { path("") { completeWith { + // Request the current storage status from the Master + val future = master ? GetStorageStatus + future.map { status => + val storageStati = status.asInstanceOf[ArrayBuffer[StorageStatus]] + + // Calculate macro-level statistics + val maxMem = storageStati.map(_.maxMem).reduce(_+_) + val remainingMem = storageStati.map(_.remainingMem).reduce(_+_) + val diskSpaceUsed = storageStati.flatMap(_.blocks.values.map(_.diskSize)) + .reduceOption(_+_).getOrElse(0L) + + // Filter out everything that's not and rdd. + val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => k.startsWith("rdd") }.toMap + val rdds = rddInfoFromBlockStati(rddBlocks) + + spark.storage.html.index.render(maxMem, remainingMem, diskSpaceUsed, rdds.toList) + } + }}} ~ + get { path("rdd") { parameter("id") { id => { completeWith { + val future = master ? GetStorageStatus + future.map { status => + val prefix = "rdd_" + id.toString + + val storageStati = status.asInstanceOf[ArrayBuffer[StorageStatus]] + val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => k.startsWith(prefix) }.toMap + val rddInfo = rddInfoFromBlockStati(rddBlocks).first + + spark.storage.html.rdd.render(rddInfo, rddBlocks) + + } + }}}}} ~ + pathPrefix("static") { + getFromResourceDirectory(STATIC_RESOURCE_DIR) + } + + } + + private def rddInfoFromBlockStati(infos: Map[String, BlockStatus]) : Array[RDDInfo] = { + infos.groupBy { case(k,v) => + // Group by rdd name, ignore the partition name + k.substring(0,k.lastIndexOf('_')) + }.map { case(k,v) => + val blockStati = v.map(_._2).toArray + // Add up memory and disk sizes + val tmp = blockStati.map { x => (x.memSize, x.diskSize)}.reduce { (x,y) => + (x._1 + y._1, x._2 + y._2) + } + // Get the friendly name for the rdd, if available. + // This is pretty hacky, is there a better way? + val rddId = k.split("_").last.toInt + val rddName : String = Option(sc.rddNames.get(rddId)).getOrElse(k) + val rddStorageLevel = sc.persistentRdds.get(rddId).getStorageLevel + RDDInfo(rddId, rddName, rddStorageLevel, blockStati.length, tmp._1, tmp._2) + }.toArray + } + +} diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index b466b5239c..13bc0f8ccc 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -50,12 +50,13 @@ private[spark] object AkkaUtils { * Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to * handle requests. Throws a SparkException if this fails. */ - def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route) { + def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, + name: String = "HttpServer") { val ioWorker = new IoWorker(actorSystem).start() val httpService = actorSystem.actorOf(Props(new HttpService(route))) val rootService = actorSystem.actorOf(Props(new SprayCanRootService(httpService))) val server = actorSystem.actorOf( - Props(new HttpServer(ioWorker, SingletonHandler(rootService))), name = "HttpServer") + Props(new HttpServer(ioWorker, SingletonHandler(rootService))), name = name) actorSystem.registerOnTermination { ioWorker.stop() } val timeout = 3.seconds val future = server.ask(HttpServer.Bind(ip, port))(timeout) diff --git a/core/src/main/twirl/spark/common/layout.scala.html b/core/src/main/twirl/spark/common/layout.scala.html new file mode 100644 index 0000000000..b9192060aa --- /dev/null +++ b/core/src/main/twirl/spark/common/layout.scala.html @@ -0,0 +1,35 @@ +@(title: String)(content: Html) + + + + + + + + + + @title + + + + +
+ + +
+
+ +

@title

+
+
+ +
+ + @content + +
+ + + \ No newline at end of file diff --git a/core/src/main/twirl/spark/deploy/common/layout.scala.html b/core/src/main/twirl/spark/deploy/common/layout.scala.html deleted file mode 100644 index b9192060aa..0000000000 --- a/core/src/main/twirl/spark/deploy/common/layout.scala.html +++ /dev/null @@ -1,35 +0,0 @@ -@(title: String)(content: Html) - - - - - - - - - - @title - - - - -
- - -
-
- -

@title

-
-
- -
- - @content - -
- - - \ No newline at end of file diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html index 7562076b00..2e15fe2200 100644 --- a/core/src/main/twirl/spark/deploy/master/index.scala.html +++ b/core/src/main/twirl/spark/deploy/master/index.scala.html @@ -1,7 +1,7 @@ @(state: spark.deploy.MasterState) @import spark.deploy.master._ -@spark.deploy.common.html.layout(title = "Spark Master on " + state.uri) { +@spark.common.html.layout(title = "Spark Master on " + state.uri) {
diff --git a/core/src/main/twirl/spark/deploy/master/job_details.scala.html b/core/src/main/twirl/spark/deploy/master/job_details.scala.html index dcf41c28f2..d02a51b214 100644 --- a/core/src/main/twirl/spark/deploy/master/job_details.scala.html +++ b/core/src/main/twirl/spark/deploy/master/job_details.scala.html @@ -1,6 +1,6 @@ @(job: spark.deploy.master.JobInfo) -@spark.deploy.common.html.layout(title = "Job Details") { +@spark.common.html.layout(title = "Job Details") {
diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html index 69746ed02c..40c2d81d77 100644 --- a/core/src/main/twirl/spark/deploy/worker/index.scala.html +++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html @@ -1,6 +1,6 @@ @(worker: spark.deploy.WorkerState) -@spark.deploy.common.html.layout(title = "Spark Worker on " + worker.uri) { +@spark.common.html.layout(title = "Spark Worker on " + worker.uri) {
diff --git a/core/src/main/twirl/spark/storage/index.scala.html b/core/src/main/twirl/spark/storage/index.scala.html new file mode 100644 index 0000000000..fa7dad51ee --- /dev/null +++ b/core/src/main/twirl/spark/storage/index.scala.html @@ -0,0 +1,28 @@ +@(maxMem: Long, remainingMem: Long, diskSpaceUsed: Long, rdds: List[spark.storage.RDDInfo]) + +@spark.common.html.layout(title = "Storage Dashboard") { + + +
+
+
    +
  • Memory: + @{spark.Utils.memoryBytesToString(maxMem - remainingMem)} Used + (@{spark.Utils.memoryBytesToString(remainingMem)} Available)
  • +
  • Disk: @{spark.Utils.memoryBytesToString(diskSpaceUsed)} Used
  • +
+
+
+ +
+ + +
+
+

RDD Summary

+
+ @rdd_table(rdds) +
+
+ +} \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd.scala.html b/core/src/main/twirl/spark/storage/rdd.scala.html new file mode 100644 index 0000000000..3a70326efe --- /dev/null +++ b/core/src/main/twirl/spark/storage/rdd.scala.html @@ -0,0 +1,65 @@ +@(rddInfo: spark.storage.RDDInfo, blocks: Map[String, spark.storage.BlockStatus]) + +@spark.common.html.layout(title = "RDD Info ") { + + +
+
+
    +
  • + Storage Level: + @(if (rddInfo.storageLevel.useDisk) "Disk" else "") + @(if (rddInfo.storageLevel.useMemory) "Memory" else "") + @(if (rddInfo.storageLevel.deserialized) "Deserialized" else "") + @(rddInfo.storageLevel.replication)x Replicated +
  • + Partitions: + @(rddInfo.numPartitions) +
  • +
  • + Memory Size: + @{spark.Utils.memoryBytesToString(rddInfo.memSize)} +
  • +
  • + Disk Size: + @{spark.Utils.memoryBytesToString(rddInfo.diskSize)} +
  • +
+
+
+ +
+ + +
+
+

RDD Summary

+
+ + + + + + + + + + + + + @blocks.map { case (k,v) => + + + + + + + } + +
Block NameStorage LevelSize in MemorySize on Disk
@k@v.storageLevel@{spark.Utils.memoryBytesToString(v.memSize)}@{spark.Utils.memoryBytesToString(v.diskSize)}
+ + +
+
+ +} \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd_row.scala.html b/core/src/main/twirl/spark/storage/rdd_row.scala.html new file mode 100644 index 0000000000..3dd9944e3b --- /dev/null +++ b/core/src/main/twirl/spark/storage/rdd_row.scala.html @@ -0,0 +1,18 @@ +@(rdd: spark.storage.RDDInfo) + + + + + @rdd.name + + + + @(if (rdd.storageLevel.useDisk) "Disk" else "") + @(if (rdd.storageLevel.useMemory) "Memory" else "") + @(if (rdd.storageLevel.deserialized) "Deserialized" else "") + @(rdd.storageLevel.replication)x Replicated + + @rdd.numPartitions + @{spark.Utils.memoryBytesToString(rdd.memSize)} + @{spark.Utils.memoryBytesToString(rdd.diskSize)} + \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd_table.scala.html b/core/src/main/twirl/spark/storage/rdd_table.scala.html new file mode 100644 index 0000000000..24f55ccefb --- /dev/null +++ b/core/src/main/twirl/spark/storage/rdd_table.scala.html @@ -0,0 +1,18 @@ +@(rdds: List[spark.storage.RDDInfo]) + + + + + + + + + + + + + @for(rdd <- rdds) { + @rdd_row(rdd) + } + +
RDD NameStorage LevelPartitionsSize in MemorySize on Disk
\ No newline at end of file -- cgit v1.2.3 From eb95212f4d24dbcd734922f39d51e6fdeaeb4c8b Mon Sep 17 00:00:00 2001 From: Denny Date: Mon, 29 Oct 2012 14:57:32 -0700 Subject: code Formatting --- core/src/main/scala/spark/storage/BlockManagerUI.scala | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index c168f60c35..635c096c87 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -21,7 +21,8 @@ object BlockManagerUI extends Logging { try { logInfo("Starting BlockManager WebUI.") val port = Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt - AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, webUIDirectives.handler, "BlockManagerHTTPServer") + AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, + webUIDirectives.handler, "BlockManagerHTTPServer") } catch { case e: Exception => logError("Failed to create BlockManager WebUI", e) @@ -32,10 +33,12 @@ object BlockManagerUI extends Logging { } private[spark] -case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, numPartitions: Int, memSize: Long, diskSize: Long) +case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, + numPartitions: Int, memSize: Long, diskSize: Long) private[spark] -class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, sc: SparkContext) extends Directives { +class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, + sc: SparkContext) extends Directives { val STATIC_RESOURCE_DIR = "spark/deploy/static" implicit val timeout = Timeout(1 seconds) @@ -55,7 +58,9 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, s .reduceOption(_+_).getOrElse(0L) // Filter out everything that's not and rdd. - val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => k.startsWith("rdd") }.toMap + val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => + k.startsWith("rdd") + }.toMap val rdds = rddInfoFromBlockStati(rddBlocks) spark.storage.html.index.render(maxMem, remainingMem, diskSpaceUsed, rdds.toList) @@ -67,7 +72,9 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, s val prefix = "rdd_" + id.toString val storageStati = status.asInstanceOf[ArrayBuffer[StorageStatus]] - val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => k.startsWith(prefix) }.toMap + val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => + k.startsWith(prefix) + }.toMap val rddInfo = rddInfoFromBlockStati(rddBlocks).first spark.storage.html.rdd.render(rddInfo, rddBlocks) -- cgit v1.2.3 From ceec1a1a6abb1fd03316e7fcc532d7e121d5bf65 Mon Sep 17 00:00:00 2001 From: Denny Date: Mon, 29 Oct 2012 15:03:01 -0700 Subject: Nicer storage level format on RDD page --- core/src/main/twirl/spark/storage/rdd.scala.html | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/twirl/spark/storage/rdd.scala.html b/core/src/main/twirl/spark/storage/rdd.scala.html index 3a70326efe..075289c826 100644 --- a/core/src/main/twirl/spark/storage/rdd.scala.html +++ b/core/src/main/twirl/spark/storage/rdd.scala.html @@ -50,7 +50,12 @@ @blocks.map { case (k,v) => @k - @v.storageLevel + + @(if (v.storageLevel.useDisk) "Disk" else "") + @(if (v.storageLevel.useMemory) "Memory" else "") + @(if (v.storageLevel.deserialized) "Deserialized" else "") + @(v.storageLevel.replication)x Replicated + @{spark.Utils.memoryBytesToString(v.memSize)} @{spark.Utils.memoryBytesToString(v.diskSize)} -- cgit v1.2.3 From 0dcd770fdc4d558972b635b6770ed0120280ef22 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 30 Oct 2012 16:09:37 -0700 Subject: Added checkpointing support to all RDDs, along with CheckpointSuite to test checkpointing in them. --- core/src/main/scala/spark/PairRDDFunctions.scala | 4 +- core/src/main/scala/spark/ParallelCollection.scala | 4 +- core/src/main/scala/spark/RDD.scala | 129 +++++++++++++++++---- core/src/main/scala/spark/SparkContext.scala | 21 ++++ core/src/main/scala/spark/rdd/BlockRDD.scala | 13 ++- core/src/main/scala/spark/rdd/CartesianRDD.scala | 38 +++--- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 19 +-- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 26 +++-- core/src/main/scala/spark/rdd/FilteredRDD.scala | 2 +- core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 2 +- core/src/main/scala/spark/rdd/GlommedRDD.scala | 2 +- core/src/main/scala/spark/rdd/HadoopRDD.scala | 2 + .../main/scala/spark/rdd/MapPartitionsRDD.scala | 2 +- .../spark/rdd/MapPartitionsWithSplitRDD.scala | 2 +- core/src/main/scala/spark/rdd/MappedRDD.scala | 2 +- core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 2 + core/src/main/scala/spark/rdd/PipedRDD.scala | 9 +- core/src/main/scala/spark/rdd/SampledRDD.scala | 2 +- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 5 - core/src/main/scala/spark/rdd/UnionRDD.scala | 32 ++--- core/src/test/scala/spark/CheckpointSuite.scala | 116 ++++++++++++++++++ core/src/test/scala/spark/RDDSuite.scala | 25 +++- 22 files changed, 352 insertions(+), 107 deletions(-) create mode 100644 core/src/test/scala/spark/CheckpointSuite.scala (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index f52af08125..1f82bd3ab8 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -625,7 +625,7 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( } private[spark] -class MappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V => U) +class MappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => U) extends RDD[(K, U)](prev.get) { override def splits = firstParent[(K, V)].splits @@ -634,7 +634,7 @@ class MappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V } private[spark] -class FlatMappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U]) +class FlatMappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U]) extends RDD[(K, U)](prev.get) { override def splits = firstParent[(K, V)].splits diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index ad06ee9736..9725017b61 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -22,10 +22,10 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( } private[spark] class ParallelCollection[T: ClassManifest]( - @transient sc_ : SparkContext, + @transient sc : SparkContext, @transient data: Seq[T], numSlices: Int) - extends RDD[T](sc_, Nil) { + extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split // instead. UPDATE: With the new changes to enable checkpointing, this an be done. diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index c9f3763f73..e272a0ede9 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -13,6 +13,7 @@ import scala.collection.Map import scala.collection.mutable.HashMap import scala.collection.JavaConversions.mapAsScalaMap +import org.apache.hadoop.fs.Path import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text @@ -74,7 +75,7 @@ import SparkContext._ */ abstract class RDD[T: ClassManifest]( @transient var sc: SparkContext, - @transient var dependencies_ : List[Dependency[_]] = Nil + var dependencies_ : List[Dependency[_]] ) extends Serializable { @@ -91,7 +92,6 @@ abstract class RDD[T: ClassManifest]( /** How this RDD depends on any parent RDDs. */ def dependencies: List[Dependency[_]] = dependencies_ - //var dependencies: List[Dependency[_]] = dependencies_ /** Record user function generating this RDD. */ private[spark] val origin = Utils.getSparkCallSite @@ -100,7 +100,13 @@ abstract class RDD[T: ClassManifest]( val partitioner: Option[Partitioner] = None /** Optionally overridden by subclasses to specify placement preferences. */ - def preferredLocations(split: Split): Seq[String] = Nil + def preferredLocations(split: Split): Seq[String] = { + if (isCheckpointed) { + checkpointRDD.preferredLocations(split) + } else { + Nil + } + } /** The [[spark.SparkContext]] that this RDD was created on. */ def context = sc @@ -113,8 +119,23 @@ abstract class RDD[T: ClassManifest]( // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE - private[spark] def firstParent[U: ClassManifest] = dependencies.head.rdd.asInstanceOf[RDD[U]] - private[spark] def parent[U: ClassManifest](id: Int) = dependencies(id).rdd.asInstanceOf[RDD[U]] + /** Returns the first parent RDD */ + private[spark] def firstParent[U: ClassManifest] = { + dependencies.head.rdd.asInstanceOf[RDD[U]] + } + + /** Returns the `i` th parent RDD */ + private[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] + + // Variables relating to checkpointing + val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD + var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing + var isCheckpointInProgress = false // set to true when checkpointing is in progress + var isCheckpointed = false // set to true after checkpointing is completed + + var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed + var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file + var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD // Methods available on all RDDs: @@ -141,32 +162,94 @@ abstract class RDD[T: ClassManifest]( /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = { - if (!level.useDisk && level.replication < 2) { - throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") - } - - // This is a hack. Ideally this should re-use the code used by the CacheTracker - // to generate the key. - def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index) - - persist(level) - sc.runJob(this, (iter: Iterator[T]) => {} ) - - val p = this.partitioner - - new BlockRDD[T](sc, splits.map(getSplitKey).toArray) { - override val partitioner = p + /** + * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` + * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. + * This is used to truncate very long lineages. In the current implementation, Spark will save + * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. + * Hence, it is strongly recommended to use checkpoint() on RDDs when + * (i) Checkpoint() is called before the any job has been executed on this RDD. + * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will + * require recomputation. + */ + protected[spark] def checkpoint() { + synchronized { + if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) { + // do nothing + } else if (isCheckpointable) { + shouldCheckpoint = true + } else { + throw new Exception(this + " cannot be checkpointed") + } } } - + + /** + * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler after a job + * using this RDD has completed (therefore the RDD has been materialized and + * potentially stored in memory). In case this RDD is not marked for checkpointing, + * doCheckpoint() is called recursively on the parent RDDs. + */ + private[spark] def doCheckpoint() { + val startCheckpoint = synchronized { + if (isCheckpointable && shouldCheckpoint && !isCheckpointInProgress) { + isCheckpointInProgress = true + true + } else { + false + } + } + + if (startCheckpoint) { + val rdd = this + val env = SparkEnv.get + + // Spawn a new thread to do the checkpoint as it takes sometime to write the RDD to file + val th = new Thread() { + override def run() { + // Save the RDD to a file, create a new HadoopRDD from it, + // and change the dependencies from the original parents to the new RDD + SparkEnv.set(env) + rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString + rdd.saveAsObjectFile(checkpointFile) + rdd.synchronized { + rdd.checkpointRDD = context.objectFile[T](checkpointFile) + rdd.checkpointRDDSplits = rdd.checkpointRDD.splits + rdd.changeDependencies(rdd.checkpointRDD) + rdd.shouldCheckpoint = false + rdd.isCheckpointInProgress = false + rdd.isCheckpointed = true + } + } + } + th.start() + } else { + // Recursively call doCheckpoint() to perform checkpointing on parent RDD if they are marked + dependencies.foreach(_.rdd.doCheckpoint()) + } + } + + /** + * Changes the dependencies of this RDD from its original parents to the new [[spark.rdd.HadoopRDD]] + * (`newRDD`) created from the checkpoint file. This method must ensure that all references + * to the original parent RDDs must be removed to enable the parent RDDs to be garbage + * collected. Subclasses of RDD may override this method for implementing their own changing + * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + */ + protected def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD)) + } + /** * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. * This should ''not'' be called by users directly, but is available for implementors of custom * subclasses of RDD. */ final def iterator(split: Split): Iterator[T] = { - if (storageLevel != StorageLevel.NONE) { + if (isCheckpointed) { + // ASSUMPTION: Checkpoint Hadoop RDD will have same number of splits as original + checkpointRDD.iterator(checkpointRDDSplits(split.index)) + } else if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) } else { compute(split) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 6b957a6356..79ceab5f4f 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -188,6 +188,8 @@ class SparkContext( private var dagScheduler = new DAGScheduler(taskScheduler) + private[spark] var checkpointDir: String = null + // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. */ @@ -519,6 +521,7 @@ class SparkContext( val start = System.nanoTime val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") + rdd.doCheckpoint() result } @@ -575,6 +578,24 @@ class SparkContext( return f } + /** + * Set the directory under which RDDs are going to be checkpointed. This method will + * create this directory and will throw an exception of the path already exists (to avoid + * overwriting existing files may be overwritten). The directory will be deleted on exit + * if indicated. + */ + def setCheckpointDir(dir: String, deleteOnExit: Boolean = false) { + val path = new Path(dir) + val fs = path.getFileSystem(new Configuration()) + if (fs.exists(path)) { + throw new Exception("Checkpoint directory '" + path + "' already exists.") + } else { + fs.mkdirs(path) + if (deleteOnExit) fs.deleteOnExit(path) + } + checkpointDir = dir + } + /** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */ def defaultParallelism: Int = taskScheduler.defaultParallelism diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index cb73976aed..f4c3f99011 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -14,7 +14,7 @@ private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split private[spark] class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) - extends RDD[T](sc) { + extends RDD[T](sc, Nil) { @transient val splits_ = (0 until blockIds.size).map(i => { @@ -41,9 +41,12 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St } } - override def preferredLocations(split: Split) = - locations_(split.asInstanceOf[BlockRDDSplit].blockId) - - override val dependencies: List[Dependency[_]] = Nil + override def preferredLocations(split: Split) = { + if (isCheckpointed) { + checkpointRDD.preferredLocations(split) + } else { + locations_(split.asInstanceOf[BlockRDDSplit].blockId) + } + } } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index c97b835630..458ad38d55 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,9 +1,6 @@ package spark.rdd -import spark.NarrowDependency -import spark.RDD -import spark.SparkContext -import spark.Split +import spark._ import java.lang.ref.WeakReference private[spark] @@ -14,19 +11,15 @@ class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with private[spark] class CartesianRDD[T: ClassManifest, U:ClassManifest]( sc: SparkContext, - rdd1_ : WeakReference[RDD[T]], - rdd2_ : WeakReference[RDD[U]]) - extends RDD[Pair[T, U]](sc) + var rdd1 : RDD[T], + var rdd2 : RDD[U]) + extends RDD[Pair[T, U]](sc, Nil) with Serializable { - def rdd1 = rdd1_.get - def rdd2 = rdd2_.get - val numSplitsInRdd2 = rdd2.splits.size - // TODO: make this null when finishing checkpoint @transient - val splits_ = { + var splits_ = { // create the cross product split val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) for (s1 <- rdd1.splits; s2 <- rdd2.splits) { @@ -36,12 +29,15 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( array } - // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override def preferredLocations(split: Split) = { - val currSplit = split.asInstanceOf[CartesianSplit] - rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) + if (isCheckpointed) { + checkpointRDD.preferredLocations(split) + } else { + val currSplit = split.asInstanceOf[CartesianSplit] + rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) + } } override def compute(split: Split) = { @@ -49,8 +45,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( for (x <- rdd1.iterator(currSplit.s1); y <- rdd2.iterator(currSplit.s2)) yield (x, y) } - // TODO: make this null when finishing checkpoint - var deps = List( + var deps_ = List( new NarrowDependency(rdd1) { def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2) }, @@ -59,5 +54,12 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( } ) - override def dependencies = deps + override def dependencies = deps_ + + override protected def changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + rdd1 = null + rdd2 = null + } } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index af54ac2fa0..a313ebcbe8 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -30,14 +30,13 @@ private[spark] class CoGroupAggregator { (b1, b2) => b1 ++ b2 }) with Serializable -class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) +class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging { val aggr = new CoGroupAggregator - // TODO: make this null when finishing checkpoint @transient - var deps = { + var deps_ = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) @@ -52,11 +51,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) deps.toList } - override def dependencies = deps + override def dependencies = deps_ - // TODO: make this null when finishing checkpoint @transient - val splits_ : Array[Split] = { + var splits_ : Array[Split] = { val firstRdd = rdds.head val array = new Array[Split](part.numPartitions) for (i <- 0 until array.size) { @@ -72,13 +70,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) array } - // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override val partitioner = Some(part) - override def preferredLocations(s: Split) = Nil - override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = { val split = s.asInstanceOf[CoGroupSplit] val numRdds = split.deps.size @@ -106,4 +101,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) } map.iterator } + + override protected def changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + rdds = null + } } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 573acf8893..5b5f72ddeb 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -1,8 +1,7 @@ package spark.rdd -import spark.NarrowDependency -import spark.RDD -import spark.Split +import spark._ +import java.lang.ref.WeakReference private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split @@ -15,13 +14,12 @@ private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) exten * or to avoid having a large number of small tasks when processing a directory with many files. */ class CoalescedRDD[T: ClassManifest]( - @transient prev: RDD[T], // TODO: Make this a weak reference + var prev: RDD[T], maxPartitions: Int) extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs - // TODO: make this null when finishing checkpoint - @transient val splits_ : Array[Split] = { - val prevSplits = firstParent[T].splits + @transient var splits_ : Array[Split] = { + val prevSplits = prev.splits if (prevSplits.length < maxPartitions) { prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) } } else { @@ -33,7 +31,6 @@ class CoalescedRDD[T: ClassManifest]( } } - // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override def compute(split: Split): Iterator[T] = { @@ -42,13 +39,18 @@ class CoalescedRDD[T: ClassManifest]( } } - // TODO: make this null when finishing checkpoint - var deps = List( - new NarrowDependency(firstParent) { + var deps_ : List[Dependency[_]] = List( + new NarrowDependency(prev) { def getParents(id: Int): Seq[Int] = splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index) } ) - override def dependencies = deps + override def dependencies = deps_ + + override protected def changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD)) + splits_ = newRDD.splits + prev = null + } } diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index cc2a3acd3a..1370cf6faf 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -7,7 +7,7 @@ import java.lang.ref.WeakReference private[spark] class FilteredRDD[T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], f: T => Boolean) extends RDD[T](prev.get) { diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index 34bd784c13..6b2cc67568 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -7,7 +7,7 @@ import java.lang.ref.WeakReference private[spark] class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], f: T => TraversableOnce[U]) extends RDD[U](prev.get) { diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index 9321e89dcd..0f0b6ab0ff 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -6,7 +6,7 @@ import spark.Split import java.lang.ref.WeakReference private[spark] -class GlommedRDD[T: ClassManifest](@transient prev: WeakReference[RDD[T]]) +class GlommedRDD[T: ClassManifest](prev: WeakReference[RDD[T]]) extends RDD[Array[T]](prev.get) { override def splits = firstParent[T].splits override def compute(split: Split) = Array(firstParent[T].iterator(split).toArray).iterator diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index a12531ea89..19ed56d9c0 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -115,4 +115,6 @@ class HadoopRDD[K, V]( val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } + + override val isCheckpointable = false } diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index bad872c430..b04f56cfcc 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -7,7 +7,7 @@ import java.lang.ref.WeakReference private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false) extends RDD[U](prev.get) { diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index d7b238b05d..7a4b6ffb03 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -12,7 +12,7 @@ import java.lang.ref.WeakReference */ private[spark] class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], f: (Int, Iterator[T]) => Iterator[U]) extends RDD[U](prev.get) { diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 126c6f332b..8fa1872e0a 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -7,7 +7,7 @@ import java.lang.ref.WeakReference private[spark] class MappedRDD[U: ClassManifest, T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], f: T => U) extends RDD[U](prev.get) { diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index c12df5839e..2875abb2db 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -93,4 +93,6 @@ class NewHadoopRDD[K, V]( val theSplit = split.asInstanceOf[NewHadoopSplit] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") } + + override val isCheckpointable = false } diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index d54579d6d1..d9293a9d1a 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -12,6 +12,7 @@ import spark.OneToOneDependency import spark.RDD import spark.SparkEnv import spark.Split +import java.lang.ref.WeakReference /** @@ -19,16 +20,16 @@ import spark.Split * (printing them one per line) and returns the output as a collection of strings. */ class PipedRDD[T: ClassManifest]( - @transient prev: RDD[T], + prev: WeakReference[RDD[T]], command: Seq[String], envVars: Map[String, String]) - extends RDD[String](prev) { + extends RDD[String](prev.get) { - def this(@transient prev: RDD[T], command: Seq[String]) = this(prev, command, Map()) + def this(prev: WeakReference[RDD[T]], command: Seq[String]) = this(prev, command, Map()) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(@transient prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) + def this(prev: WeakReference[RDD[T]], command: String) = this(prev, PipedRDD.tokenize(command)) override def splits = firstParent[T].splits diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index 00b521b130..f273f257f8 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -15,7 +15,7 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali } class SampledRDD[T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], withReplacement: Boolean, frac: Double, seed: Int) diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 62867dab4f..b7d843c26d 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -31,11 +31,6 @@ class ShuffledRDD[K, V]( override def splits = splits_ - override def preferredLocations(split: Split) = Nil - - //val dep = new ShuffleDependency(parent, part) - //override val dependencies = List(dep) - override def compute(split: Split): Iterator[(K, V)] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 0a61a2d1f5..643a174160 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -2,11 +2,7 @@ package spark.rdd import scala.collection.mutable.ArrayBuffer -import spark.Dependency -import spark.RangeDependency -import spark.RDD -import spark.SparkContext -import spark.Split +import spark._ import java.lang.ref.WeakReference private[spark] class UnionSplit[T: ClassManifest]( @@ -23,12 +19,11 @@ private[spark] class UnionSplit[T: ClassManifest]( class UnionRDD[T: ClassManifest]( sc: SparkContext, - @transient rdds: Seq[RDD[T]]) // TODO: Make this a weak reference + @transient var rdds: Seq[RDD[T]]) extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs - // TODO: make this null when finishing checkpoint @transient - val splits_ : Array[Split] = { + var splits_ : Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) var pos = 0 for (rdd <- rdds; split <- rdd.splits) { @@ -38,11 +33,9 @@ class UnionRDD[T: ClassManifest]( array } - // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ - // TODO: make this null when finishing checkpoint - @transient var deps = { + @transient var deps_ = { val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { @@ -52,10 +45,21 @@ class UnionRDD[T: ClassManifest]( deps.toList } - override def dependencies = deps + override def dependencies = deps_ override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() - override def preferredLocations(s: Split): Seq[String] = - s.asInstanceOf[UnionSplit[T]].preferredLocations() + override def preferredLocations(s: Split): Seq[String] = { + if (isCheckpointed) { + checkpointRDD.preferredLocations(s) + } else { + s.asInstanceOf[UnionSplit[T]].preferredLocations() + } + } + + override protected def changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD)) + splits_ = newRDD.splits + rdds = null + } } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala new file mode 100644 index 0000000000..0e5ca7dc21 --- /dev/null +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -0,0 +1,116 @@ +package spark + +import org.scalatest.{BeforeAndAfter, FunSuite} +import java.io.File +import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD} +import spark.SparkContext._ +import storage.StorageLevel + +class CheckpointSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + var checkpointDir: File = _ + + before { + checkpointDir = File.createTempFile("temp", "") + checkpointDir.delete() + + sc = new SparkContext("local", "test") + sc.setCheckpointDir(checkpointDir.toString) + } + + after { + if (sc != null) { + sc.stop() + sc = null + } + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") + + if (checkpointDir != null) { + checkpointDir.delete() + } + } + + test("ParallelCollection") { + val parCollection = sc.makeRDD(1 to 4) + parCollection.checkpoint() + assert(parCollection.dependencies === Nil) + val result = parCollection.collect() + sleep(parCollection) // slightly extra time as loading classes for the first can take some time + assert(sc.objectFile[Int](parCollection.checkpointFile).collect() === result) + assert(parCollection.dependencies != Nil) + assert(parCollection.collect() === result) + } + + test("BlockRDD") { + val blockId = "id" + val blockManager = SparkEnv.get.blockManager + blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) + val blockRDD = new BlockRDD[String](sc, Array(blockId)) + blockRDD.checkpoint() + val result = blockRDD.collect() + sleep(blockRDD) + assert(sc.objectFile[String](blockRDD.checkpointFile).collect() === result) + assert(blockRDD.dependencies != Nil) + assert(blockRDD.collect() === result) + } + + test("RDDs with one-to-one dependencies") { + testCheckpointing(_.map(x => x.toString)) + testCheckpointing(_.flatMap(x => 1 to x)) + testCheckpointing(_.filter(_ % 2 == 0)) + testCheckpointing(_.sample(false, 0.5, 0)) + testCheckpointing(_.glom()) + testCheckpointing(_.mapPartitions(_.map(_.toString))) + testCheckpointing(r => new MapPartitionsWithSplitRDD(r, + (i: Int, iter: Iterator[Int]) => iter.map(_.toString) )) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), 1000) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x), 1000) + testCheckpointing(_.pipe(Seq("cat"))) + } + + test("ShuffledRDD") { + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _)) + } + + test("UnionRDD") { + testCheckpointing(_.union(sc.makeRDD(5 to 6, 4))) + } + + test("CartesianRDD") { + testCheckpointing(_.cartesian(sc.makeRDD(5 to 6, 4)), 1000) + } + + test("CoalescedRDD") { + testCheckpointing(new CoalescedRDD(_, 2)) + } + + test("CoGroupedRDD") { + val rdd2 = sc.makeRDD(5 to 6, 4).map(x => (x % 2, 1)) + testCheckpointing(rdd1 => rdd1.map(x => (x % 2, 1)).cogroup(rdd2)) + testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2)) + } + + def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { + val parCollection = sc.makeRDD(1 to 4, 4) + val operatedRDD = op(parCollection) + operatedRDD.checkpoint() + val parentRDD = operatedRDD.dependencies.head.rdd + val result = operatedRDD.collect() + sleep(operatedRDD) + //println(parentRDD + ", " + operatedRDD.dependencies.head.rdd ) + assert(sc.objectFile[U](operatedRDD.checkpointFile).collect() === result) + assert(operatedRDD.dependencies.head.rdd != parentRDD) + assert(operatedRDD.collect() === result) + } + + def sleep(rdd: RDD[_]) { + val startTime = System.currentTimeMillis() + val maxWaitTime = 5000 + while(rdd.isCheckpointed == false && System.currentTimeMillis() < startTime + maxWaitTime) { + Thread.sleep(50) + } + assert(rdd.isCheckpointed === true, "Waiting for checkpoint to complete took more than " + maxWaitTime + " ms") + } +} diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 37a0ff0947..8ac7c8451a 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -19,7 +19,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.master.port") } - + test("basic operations") { sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) @@ -70,10 +70,23 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) } - test("checkpointing") { + test("basic checkpointing") { + import java.io.File + val checkpointDir = File.createTempFile("temp", "") + checkpointDir.delete() + sc = new SparkContext("local", "test") - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).flatMap(x => 1 to x).checkpoint() - assert(rdd.collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) + sc.setCheckpointDir(checkpointDir.toString) + val parCollection = sc.makeRDD(1 to 4) + val flatMappedRDD = parCollection.flatMap(x => 1 to x) + flatMappedRDD.checkpoint() + assert(flatMappedRDD.dependencies.head.rdd == parCollection) + val result = flatMappedRDD.collect() + Thread.sleep(1000) + assert(flatMappedRDD.dependencies.head.rdd != parCollection) + assert(flatMappedRDD.collect() === result) + + checkpointDir.deleteOnExit() } test("basic caching") { @@ -94,8 +107,8 @@ class RDDSuite extends FunSuite with BeforeAndAfter { List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) // Check that the narrow dependency is also specified correctly - assert(coalesced1.dependencies.head.getParents(0).toList === List(0, 1, 2, 3, 4)) - assert(coalesced1.dependencies.head.getParents(1).toList === List(5, 6, 7, 8, 9)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === List(0, 1, 2, 3, 4)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9)) val coalesced2 = new CoalescedRDD(data, 3) assert(coalesced2.collect().toList === (1 to 10).toList) -- cgit v1.2.3 From 34e569f40e184a6a4f21e9d79b0e8979d8f9541f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 31 Oct 2012 00:56:40 -0700 Subject: Added 'synchronized' to RDD serialization to ensure checkpoint-related changes are reflected atomically in the task closure. Added to tests to ensure that jobs running on an RDD on which checkpointing is in progress does hurt the result of the job. --- core/src/main/scala/spark/RDD.scala | 18 ++++++- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 7 ++- core/src/test/scala/spark/CheckpointSuite.scala | 71 ++++++++++++++++++++++++- 3 files changed, 92 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e272a0ede9..7b59a6f09e 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -1,8 +1,7 @@ package spark -import java.io.EOFException +import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream} import java.net.URL -import java.io.ObjectInputStream import java.util.concurrent.atomic.AtomicLong import java.util.Random import java.util.Date @@ -589,4 +588,19 @@ abstract class RDD[T: ClassManifest]( private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) } + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + synchronized { + oos.defaultWriteObject() + } + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + synchronized { + ois.defaultReadObject() + } + } + } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index b7d843c26d..31774585f4 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -27,7 +27,7 @@ class ShuffledRDD[K, V]( override val partitioner = Some(part) @transient - val splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) + var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) override def splits = splits_ @@ -35,4 +35,9 @@ class ShuffledRDD[K, V]( val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) } + + override def changeDependencies(newRDD: RDD[_]) { + dependencies_ = Nil + splits_ = newRDD.splits + } } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 0e5ca7dc21..57dc43ddac 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -5,8 +5,10 @@ import java.io.File import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD} import spark.SparkContext._ import storage.StorageLevel +import java.util.concurrent.Semaphore -class CheckpointSuite extends FunSuite with BeforeAndAfter { +class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { + initLogging() var sc: SparkContext = _ var checkpointDir: File = _ @@ -92,6 +94,35 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter { testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2)) } + /** + * This test forces two ResultTasks of the same job to be launched before and after + * the checkpointing of job's RDD is completed. + */ + test("Threading - ResultTasks") { + val op1 = (parCollection: RDD[Int]) => { + parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) }) + } + val op2 = (firstRDD: RDD[(Int, Int)]) => { + firstRDD.map(x => { println("2nd map running on " + x); Thread.sleep(500); x }) + } + testThreading(op1, op2) + } + + /** + * This test forces two ShuffleMapTasks of the same job to be launched before and after + * the checkpointing of job's RDD is completed. + */ + test("Threading - ShuffleMapTasks") { + val op1 = (parCollection: RDD[Int]) => { + parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) }) + } + val op2 = (firstRDD: RDD[(Int, Int)]) => { + firstRDD.groupByKey(2).map(x => { println("2nd map running on " + x); Thread.sleep(500); x }) + } + testThreading(op1, op2) + } + + def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { val parCollection = sc.makeRDD(1 to 4, 4) val operatedRDD = op(parCollection) @@ -105,6 +136,44 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter { assert(operatedRDD.collect() === result) } + def testThreading[U: ClassManifest, V: ClassManifest](op1: (RDD[Int]) => RDD[U], op2: (RDD[U]) => RDD[V]) { + + val parCollection = sc.makeRDD(1 to 2, 2) + + // This is the RDD that is to be checkpointed + val firstRDD = op1(parCollection) + val parentRDD = firstRDD.dependencies.head.rdd + firstRDD.checkpoint() + + // This the RDD that uses firstRDD. This is designed to launch a + // ShuffleMapTask that uses firstRDD. + val secondRDD = op2(firstRDD) + + // Starting first job, to initiate the checkpointing + logInfo("\nLaunching 1st job to initiate checkpointing\n") + firstRDD.collect() + + // Checkpointing has started but not completed yet + Thread.sleep(100) + assert(firstRDD.dependencies.head.rdd === parentRDD) + + // Starting second job; first task of this job will be + // launched _before_ firstRDD is marked as checkpointed + // and the second task will be launched _after_ firstRDD + // is marked as checkpointed + logInfo("\nLaunching 2nd job that is designed to launch tasks " + + "before and after checkpointing is complete\n") + val result = secondRDD.collect() + + // Check whether firstRDD has been successfully checkpointed + assert(firstRDD.dependencies.head.rdd != parentRDD) + + logInfo("\nRecomputing 2nd job to verify the results of the previous computation\n") + // Check whether the result in the previous job was correct or not + val correctResult = secondRDD.collect() + assert(result === correctResult) + } + def sleep(rdd: RDD[_]) { val startTime = System.currentTimeMillis() val maxWaitTime = 5000 -- cgit v1.2.3 From d1542387891018914fdd6b647f17f0b05acdd40e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 4 Nov 2012 12:12:06 -0800 Subject: Made checkpointing of dstream graph to work with checkpointing of RDDs. For streams requiring checkpointing of its RDD, the default checkpoint interval is set to 10 seconds. --- core/src/main/scala/spark/RDD.scala | 32 +++-- core/src/main/scala/spark/SparkContext.scala | 13 +- .../main/scala/spark/streaming/Checkpoint.scala | 83 ++++++----- .../src/main/scala/spark/streaming/DStream.scala | 156 ++++++++++++++++----- .../main/scala/spark/streaming/DStreamGraph.scala | 7 +- .../spark/streaming/ReducedWindowedDStream.scala | 36 +++-- .../main/scala/spark/streaming/StateDStream.scala | 45 +----- .../scala/spark/streaming/StreamingContext.scala | 38 +++-- .../src/main/scala/spark/streaming/Time.scala | 4 + .../examples/FileStreamWithCheckpoint.scala | 10 +- .../streaming/examples/TopKWordCountRaw.scala | 5 +- .../spark/streaming/examples/WordCount2.scala | 7 +- .../spark/streaming/examples/WordCountRaw.scala | 6 +- .../scala/spark/streaming/examples/WordMax2.scala | 10 +- .../scala/spark/streaming/CheckpointSuite.scala | 77 +++++++--- .../test/scala/spark/streaming/TestSuiteBase.scala | 37 +++-- .../spark/streaming/WindowOperationsSuite.scala | 4 +- 17 files changed, 367 insertions(+), 203 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 7b59a6f09e..63048d5df0 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -119,22 +119,23 @@ abstract class RDD[T: ClassManifest]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Returns the first parent RDD */ - private[spark] def firstParent[U: ClassManifest] = { + protected[spark] def firstParent[U: ClassManifest] = { dependencies.head.rdd.asInstanceOf[RDD[U]] } /** Returns the `i` th parent RDD */ - private[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] + protected[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] // Variables relating to checkpointing - val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD - var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing - var isCheckpointInProgress = false // set to true when checkpointing is in progress - var isCheckpointed = false // set to true after checkpointing is completed + protected val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD - var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed - var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file - var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD + protected var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing + protected var isCheckpointInProgress = false // set to true when checkpointing is in progress + protected[spark] var isCheckpointed = false // set to true after checkpointing is completed + + protected[spark] var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed + protected var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file + protected var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD // Methods available on all RDDs: @@ -176,6 +177,9 @@ abstract class RDD[T: ClassManifest]( if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) { // do nothing } else if (isCheckpointable) { + if (sc.checkpointDir == null) { + throw new Exception("Checkpoint directory has not been set in the SparkContext.") + } shouldCheckpoint = true } else { throw new Exception(this + " cannot be checkpointed") @@ -183,6 +187,16 @@ abstract class RDD[T: ClassManifest]( } } + def getCheckpointData(): Any = { + synchronized { + if (isCheckpointed) { + checkpointFile + } else { + null + } + } + } + /** * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler after a job * using this RDD has completed (therefore the RDD has been materialized and diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 79ceab5f4f..d7326971a9 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -584,14 +584,15 @@ class SparkContext( * overwriting existing files may be overwritten). The directory will be deleted on exit * if indicated. */ - def setCheckpointDir(dir: String, deleteOnExit: Boolean = false) { + def setCheckpointDir(dir: String, useExisting: Boolean = false) { val path = new Path(dir) val fs = path.getFileSystem(new Configuration()) - if (fs.exists(path)) { - throw new Exception("Checkpoint directory '" + path + "' already exists.") - } else { - fs.mkdirs(path) - if (deleteOnExit) fs.deleteOnExit(path) + if (!useExisting) { + if (fs.exists(path)) { + throw new Exception("Checkpoint directory '" + path + "' already exists.") + } else { + fs.mkdirs(path) + } } checkpointDir = dir } diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 83a43d15cb..cf04c7031e 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -1,6 +1,6 @@ package spark.streaming -import spark.Utils +import spark.{Logging, Utils} import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration @@ -8,13 +8,14 @@ import org.apache.hadoop.conf.Configuration import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream} -class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Serializable { +class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) + extends Logging with Serializable { val master = ssc.sc.master val framework = ssc.sc.jobName val sparkHome = ssc.sc.sparkHome val jars = ssc.sc.jars val graph = ssc.graph - val checkpointFile = ssc.checkpointFile + val checkpointDir = ssc.checkpointDir val checkpointInterval = ssc.checkpointInterval validate() @@ -24,22 +25,25 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext assert(framework != null, "Checkpoint.framework is null") assert(graph != null, "Checkpoint.graph is null") assert(checkpointTime != null, "Checkpoint.checkpointTime is null") + logInfo("Checkpoint for time " + checkpointTime + " validated") } - def saveToFile(file: String = checkpointFile) { - val path = new Path(file) + def save(path: String) { + val file = new Path(path, "graph") val conf = new Configuration() - val fs = path.getFileSystem(conf) - if (fs.exists(path)) { - val bkPath = new Path(path.getParent, path.getName + ".bk") - FileUtil.copy(fs, path, fs, bkPath, true, true, conf) - //logInfo("Moved existing checkpoint file to " + bkPath) + val fs = file.getFileSystem(conf) + logDebug("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'") + if (fs.exists(file)) { + val bkFile = new Path(file.getParent, file.getName + ".bk") + FileUtil.copy(fs, file, fs, bkFile, true, true, conf) + logDebug("Moved existing checkpoint file to " + bkFile) } - val fos = fs.create(path) + val fos = fs.create(file) val oos = new ObjectOutputStream(fos) oos.writeObject(this) oos.close() fs.close() + logInfo("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'") } def toBytes(): Array[Byte] = { @@ -50,30 +54,41 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext object Checkpoint { - def loadFromFile(file: String): Checkpoint = { - try { - val path = new Path(file) - val conf = new Configuration() - val fs = path.getFileSystem(conf) - if (!fs.exists(path)) { - throw new Exception("Checkpoint file '" + file + "' does not exist") + def load(path: String): Checkpoint = { + + val fs = new Path(path).getFileSystem(new Configuration()) + val attempts = Seq(new Path(path), new Path(path, "graph"), new Path(path, "graph.bk")) + var lastException: Exception = null + var lastExceptionFile: String = null + + attempts.foreach(file => { + if (fs.exists(file)) { + try { + val fis = fs.open(file) + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader) + val cp = ois.readObject.asInstanceOf[Checkpoint] + ois.close() + fs.close() + cp.validate() + println("Checkpoint successfully loaded from file " + file) + return cp + } catch { + case e: Exception => + lastException = e + lastExceptionFile = file.toString + } } - val fis = fs.open(path) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader) - val cp = ois.readObject.asInstanceOf[Checkpoint] - ois.close() - fs.close() - cp.validate() - cp - } catch { - case e: Exception => - e.printStackTrace() - throw new Exception("Could not load checkpoint file '" + file + "'", e) + }) + + if (lastException == null) { + throw new Exception("Could not load checkpoint from path '" + path + "'") + } else { + throw new Exception("Error loading checkpoint from path '" + lastExceptionFile + "'", lastException) } } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index a4921bb1a2..de51c5d34a 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -13,6 +13,7 @@ import scala.collection.mutable.HashMap import java.util.concurrent.ArrayBlockingQueue import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import scala.Some +import collection.mutable abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) extends Serializable with Logging { @@ -41,53 +42,55 @@ extends Serializable with Logging { */ // RDDs generated, marked as protected[streaming] so that testsuites can access it - protected[streaming] val generatedRDDs = new HashMap[Time, RDD[T]] () + protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () // Time zero for the DStream - protected var zeroTime: Time = null + protected[streaming] var zeroTime: Time = null // Duration for which the DStream will remember each RDD created - protected var rememberDuration: Time = null + protected[streaming] var rememberDuration: Time = null // Storage level of the RDDs in the stream - protected var storageLevel: StorageLevel = StorageLevel.NONE + protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE - // Checkpoint level and checkpoint interval - protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint - protected var checkpointInterval: Time = null + // Checkpoint details + protected[streaming] val mustCheckpoint = false + protected[streaming] var checkpointInterval: Time = null + protected[streaming] val checkpointData = new HashMap[Time, Any]() // Reference to whole DStream graph - protected var graph: DStreamGraph = null + protected[streaming] var graph: DStreamGraph = null def isInitialized = (zeroTime != null) // Duration for which the DStream requires its parent DStream to remember each RDD created def parentRememberDuration = rememberDuration - // Change this RDD's storage level - def persist( - storageLevel: StorageLevel, - checkpointLevel: StorageLevel, - checkpointInterval: Time): DStream[T] = { - if (this.storageLevel != StorageLevel.NONE && this.storageLevel != storageLevel) { - // TODO: not sure this is necessary for DStreams + // Set caching level for the RDDs created by this DStream + def persist(level: StorageLevel): DStream[T] = { + if (this.isInitialized) { throw new UnsupportedOperationException( - "Cannot change storage level of an DStream after it was already assigned a level") + "Cannot change storage level of an DStream after streaming context has started") } - this.storageLevel = storageLevel - this.checkpointLevel = checkpointLevel - this.checkpointInterval = checkpointInterval + this.storageLevel = level this } - // Set caching level for the RDDs created by this DStream - def persist(newLevel: StorageLevel): DStream[T] = persist(newLevel, StorageLevel.NONE, null) - def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY) // Turn on the default caching level for this RDD def cache(): DStream[T] = persist() + def checkpoint(interval: Time): DStream[T] = { + if (isInitialized) { + throw new UnsupportedOperationException( + "Cannot change checkpoint interval of an DStream after streaming context has started") + } + persist() + checkpointInterval = interval + this + } + /** * This method initializes the DStream by setting the "zero" time, based on which * the validity of future times is calculated. This method also recursively initializes @@ -99,7 +102,67 @@ extends Serializable with Logging { + ", cannot initialize it again to " + time) } zeroTime = time + + // Set the checkpoint interval to be slideTime or 10 seconds, which ever is larger + if (mustCheckpoint && checkpointInterval == null) { + checkpointInterval = slideTime.max(Seconds(10)) + logInfo("Checkpoint interval automatically set to " + checkpointInterval) + } + + // Set the minimum value of the rememberDuration if not already set + var minRememberDuration = slideTime + if (checkpointInterval != null && minRememberDuration <= checkpointInterval) { + minRememberDuration = checkpointInterval + slideTime + } + if (rememberDuration == null || rememberDuration < minRememberDuration) { + rememberDuration = minRememberDuration + } + + // Initialize the dependencies dependencies.foreach(_.initialize(zeroTime)) + } + + protected[streaming] def validate() { + assert( + !mustCheckpoint || checkpointInterval != null, + "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " + + " Please use DStream.checkpoint() to set the interval." + ) + + assert( + checkpointInterval == null || checkpointInterval >= slideTime, + "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + + checkpointInterval + " which is lower than its slide time (" + slideTime + "). " + + "Please set it to at least " + slideTime + "." + ) + + assert( + checkpointInterval == null || checkpointInterval.isMultipleOf(slideTime), + "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + + checkpointInterval + " which not a multiple of its slide time (" + slideTime + "). " + + "Please set it to a multiple " + slideTime + "." + ) + + assert( + checkpointInterval == null || storageLevel != StorageLevel.NONE, + "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " + + "level has not been set to enable persisting. Please use DStream.persist() to set the " + + "storage level to use memory for better checkpointing performance." + ) + + assert( + checkpointInterval == null || rememberDuration > checkpointInterval, + "The remember duration for " + this.getClass.getSimpleName + " has been set to " + + rememberDuration + " which is not more than the checkpoint interval (" + + checkpointInterval + "). Please set it to higher than " + checkpointInterval + "." + ) + + dependencies.foreach(_.validate()) + + logInfo("Slide time = " + slideTime) + logInfo("Storage level = " + storageLevel) + logInfo("Checkpoint interval = " + checkpointInterval) + logInfo("Remember duration = " + rememberDuration) logInfo("Initialized " + this) } @@ -120,17 +183,12 @@ extends Serializable with Logging { dependencies.foreach(_.setGraph(graph)) } - protected[streaming] def setRememberDuration(duration: Time = slideTime) { - if (duration == null) { - throw new Exception("Duration for remembering RDDs cannot be set to null for " + this) - } else if (rememberDuration != null && duration < rememberDuration) { - logWarning("Duration for remembering RDDs cannot be reduced from " + rememberDuration - + " to " + duration + " for " + this) - } else { + protected[streaming] def setRememberDuration(duration: Time) { + if (duration != null && duration > rememberDuration) { rememberDuration = duration - dependencies.foreach(_.setRememberDuration(parentRememberDuration)) logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) } + dependencies.foreach(_.setRememberDuration(parentRememberDuration)) } /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ @@ -163,12 +221,13 @@ extends Serializable with Logging { if (isTimeValid(time)) { compute(time) match { case Some(newRDD) => - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { - newRDD.persist(checkpointLevel) - logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) - } else if (storageLevel != StorageLevel.NONE) { + if (storageLevel != StorageLevel.NONE) { newRDD.persist(storageLevel) - logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) + logInfo("Persisting RDD for time " + time + " to " + storageLevel + " at time " + time) + } + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.checkpoint() + logInfo("Marking RDD for time " + time + " for checkpointing at time " + time) } generatedRDDs.put(time, newRDD) Some(newRDD) @@ -199,7 +258,7 @@ extends Serializable with Logging { } } - def forgetOldRDDs(time: Time) { + protected[streaming] def forgetOldRDDs(time: Time) { val keys = generatedRDDs.keys var numForgotten = 0 keys.foreach(t => { @@ -213,12 +272,35 @@ extends Serializable with Logging { dependencies.foreach(_.forgetOldRDDs(time)) } + protected[streaming] def updateCheckpointData() { + checkpointData.clear() + generatedRDDs.foreach { + case(time, rdd) => { + logDebug("Adding checkpointed RDD for time " + time) + val data = rdd.getCheckpointData() + if (data != null) { + checkpointData += ((time, data)) + } + } + } + } + + protected[streaming] def restoreCheckpointData() { + checkpointData.foreach { + case(time, data) => { + logInfo("Restoring checkpointed RDD for time " + time) + generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString))) + } + } + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { logDebug(this.getClass().getSimpleName + ".writeObject used") if (graph != null) { graph.synchronized { if (graph.checkpointInProgress) { + updateCheckpointData() oos.defaultWriteObject() } else { val msg = "Object of " + this.getClass.getName + " is being serialized " + @@ -239,6 +321,8 @@ extends Serializable with Logging { private def readObject(ois: ObjectInputStream) { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() + generatedRDDs = new HashMap[Time, RDD[T]] () + restoreCheckpointData() } /** diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index ac44d7a2a6..f8922ec790 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -22,11 +22,8 @@ final class DStreamGraph extends Serializable with Logging { } zeroTime = time outputStreams.foreach(_.initialize(zeroTime)) - outputStreams.foreach(_.setRememberDuration()) // first set the rememberDuration to default values - if (rememberDuration != null) { - // if custom rememberDuration has been provided, set the rememberDuration - outputStreams.foreach(_.setRememberDuration(rememberDuration)) - } + outputStreams.foreach(_.setRememberDuration(rememberDuration)) + outputStreams.foreach(_.validate) inputStreams.par.foreach(_.start()) } } diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index 1c57d5f855..6df82c0df3 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -21,15 +21,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( partitioner: Partitioner ) extends DStream[(K,V)](parent.ssc) { - if (!_windowTime.isMultipleOf(parent.slideTime)) - throw new Exception("The window duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + assert(_windowTime.isMultipleOf(parent.slideTime), + "The window duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + ) - if (!_slideTime.isMultipleOf(parent.slideTime)) - throw new Exception("The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + assert(_slideTime.isMultipleOf(parent.slideTime), + "The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + ) - @transient val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + super.persist(StorageLevel.MEMORY_ONLY) + + val reducedStream = parent.reduceByKey(reduceFunc, partitioner) def windowTime: Time = _windowTime @@ -37,15 +41,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( override def slideTime: Time = _slideTime - //TODO: This is wrong. This should depend on the checkpointInterval + override val mustCheckpoint = true + override def parentRememberDuration: Time = rememberDuration + windowTime - override def persist( - storageLevel: StorageLevel, - checkpointLevel: StorageLevel, - checkpointInterval: Time): DStream[(K,V)] = { - super.persist(storageLevel, checkpointLevel, checkpointInterval) - reducedStream.persist(storageLevel, checkpointLevel, checkpointInterval) + override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { + super.persist(storageLevel) + reducedStream.persist(storageLevel) + this + } + + override def checkpoint(interval: Time): DStream[(K, V)] = { + super.checkpoint(interval) + reducedStream.checkpoint(interval) this } diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index 086752ac55..0211df1343 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -23,51 +23,14 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife rememberPartitioner: Boolean ) extends DStream[(K, S)](parent.ssc) { + super.persist(StorageLevel.MEMORY_ONLY) + override def dependencies = List(parent) override def slideTime = parent.slideTime - override def getOrCompute(time: Time): Option[RDD[(K, S)]] = { - generatedRDDs.get(time) match { - case Some(oldRDD) => { - if (checkpointInterval != null && time > zeroTime && (time - zeroTime).isMultipleOf(checkpointInterval) && oldRDD.dependencies.size > 0) { - val r = oldRDD - val oldRDDBlockIds = oldRDD.splits.map(s => "rdd:" + r.id + ":" + s.index) - val checkpointedRDD = new BlockRDD[(K, S)](ssc.sc, oldRDDBlockIds) { - override val partitioner = oldRDD.partitioner - } - generatedRDDs.update(time, checkpointedRDD) - logInfo("Checkpointed RDD " + oldRDD.id + " of time " + time + " with its new RDD " + checkpointedRDD.id) - Some(checkpointedRDD) - } else { - Some(oldRDD) - } - } - case None => { - if (isTimeValid(time)) { - compute(time) match { - case Some(newRDD) => { - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { - newRDD.persist(checkpointLevel) - logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) - } else if (storageLevel != StorageLevel.NONE) { - newRDD.persist(storageLevel) - logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) - } - generatedRDDs.put(time, newRDD) - Some(newRDD) - } - case None => { - None - } - } - } else { - None - } - } - } - } - + override val mustCheckpoint = true + override def compute(validTime: Time): Option[RDD[(K, S)]] = { // Try to get the previous state RDD diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index b3148eaa97..3838e84113 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -15,6 +15,8 @@ import org.apache.hadoop.io.LongWritable import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.apache.hadoop.fs.Path +import java.util.UUID class StreamingContext ( sc_ : SparkContext, @@ -26,7 +28,7 @@ class StreamingContext ( def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) = this(new SparkContext(master, frameworkName, sparkHome, jars), null) - def this(file: String) = this(null, Checkpoint.loadFromFile(file)) + def this(path: String) = this(null, Checkpoint.load(path)) def this(cp_ : Checkpoint) = this(null, cp_) @@ -51,7 +53,6 @@ class StreamingContext ( val graph: DStreamGraph = { if (isCheckpointPresent) { - cp_.graph.setContext(this) cp_.graph } else { @@ -62,7 +63,15 @@ class StreamingContext ( val nextNetworkInputStreamId = new AtomicInteger(0) var networkInputTracker: NetworkInputTracker = null - private[streaming] var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null + private[streaming] var checkpointDir: String = { + if (isCheckpointPresent) { + sc.setCheckpointDir(cp_.checkpointDir, true) + cp_.checkpointDir + } else { + null + } + } + private[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null private[streaming] var receiverJobThread: Thread = null private[streaming] var scheduler: Scheduler = null @@ -75,9 +84,15 @@ class StreamingContext ( graph.setRememberDuration(duration) } - def setCheckpointDetails(file: String, interval: Time) { - checkpointFile = file - checkpointInterval = interval + def checkpoint(dir: String, interval: Time) { + if (dir != null) { + sc.setCheckpointDir(new Path(dir, "rdds-" + UUID.randomUUID.toString).toString) + checkpointDir = dir + checkpointInterval = interval + } else { + checkpointDir = null + checkpointInterval = null + } } private[streaming] def getInitialCheckpoint(): Checkpoint = { @@ -170,16 +185,12 @@ class StreamingContext ( graph.addOutputStream(outputStream) } - def validate() { - assert(graph != null, "Graph is null") - graph.validate() - } - /** * This function starts the execution of the streams. */ def start() { - validate() + assert(graph != null, "Graph is null") + graph.validate() val networkInputStreams = graph.getInputStreams().filter(s => s match { case n: NetworkInputDStream[_] => true @@ -216,7 +227,8 @@ class StreamingContext ( } def doCheckpoint(currentTime: Time) { - new Checkpoint(this, currentTime).saveToFile(checkpointFile) + new Checkpoint(this, currentTime).save(checkpointDir) + } } diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 9ddb65249a..2ba6502971 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -25,6 +25,10 @@ case class Time(millis: Long) { def isMultipleOf(that: Time): Boolean = (this.millis % that.millis == 0) + def min(that: Time): Time = if (this < that) this else that + + def max(that: Time): Time = if (this > that) this else that + def isZero: Boolean = (this.millis == 0) override def toString: String = (millis.toString + " ms") diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala index df96a811da..21a83c0fde 100644 --- a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala +++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala @@ -10,20 +10,20 @@ object FileStreamWithCheckpoint { def main(args: Array[String]) { if (args.size != 3) { - println("FileStreamWithCheckpoint ") - println("FileStreamWithCheckpoint restart ") + println("FileStreamWithCheckpoint ") + println("FileStreamWithCheckpoint restart ") System.exit(-1) } val directory = new Path(args(1)) - val checkpointFile = args(2) + val checkpointDir = args(2) val ssc: StreamingContext = { if (args(0) == "restart") { // Recreated streaming context from specified checkpoint file - new StreamingContext(checkpointFile) + new StreamingContext(checkpointDir) } else { @@ -34,7 +34,7 @@ object FileStreamWithCheckpoint { // Create new streaming context val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint") ssc_.setBatchDuration(Seconds(1)) - ssc_.setCheckpointDetails(checkpointFile, Seconds(1)) + ssc_.checkpoint(checkpointDir, Seconds(1)) // Setup the streaming computation val inputStream = ssc_.textFileStream(directory.toString) diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index 57fd10f0a5..750cb7445f 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -41,9 +41,8 @@ object TopKWordCountRaw { val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - Milliseconds(chkptMs)) - //windowedCounts.print() // TODO: something else? + windowedCounts.persist().checkpoint(Milliseconds(chkptMs)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs)) def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = { val taken = new Array[(String, Long)](k) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index 0d2e62b955..865026033e 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -100,10 +100,9 @@ object WordCount2 { val windowedCounts = sentences .mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, - StorageLevel.MEMORY_ONLY_2, - //new StorageLevel(false, true, true, 3), - Milliseconds(chkptMillis.toLong)) + + windowedCounts.persist().checkpoint(Milliseconds(chkptMillis.toLong)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index abfd12890f..d1ea9a9cd5 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -41,9 +41,9 @@ object WordCountRaw { val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - Milliseconds(chkptMs)) - //windowedCounts.print() // TODO: something else? + windowedCounts.persist().checkpoint(chkptMs) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs)) + windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala index 9d44da2b11..6a9c8a9a69 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala @@ -57,11 +57,13 @@ object WordMax2 { val windowedCounts = sentences .mapPartitions(splitAndCountPartitions) .reduceByKey(add _, reduceTasks.toInt) - .persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - Milliseconds(chkptMillis.toLong)) + .persist() + .checkpoint(Milliseconds(chkptMillis.toLong)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) .reduceByKeyAndWindow(max _, Seconds(10), batchDuration, reduceTasks.toInt) - //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - // Milliseconds(chkptMillis.toLong)) + .persist() + .checkpoint(Milliseconds(chkptMillis.toLong)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 6dcedcf463..dfe31b5771 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -2,52 +2,95 @@ package spark.streaming import spark.streaming.StreamingContext._ import java.io.File +import collection.mutable.ArrayBuffer +import runtime.RichInt +import org.scalatest.BeforeAndAfter +import org.apache.hadoop.fs.Path +import org.apache.commons.io.FileUtils -class CheckpointSuite extends TestSuiteBase { +class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { + + before { + FileUtils.deleteDirectory(new File(checkpointDir)) + } + + after { + FileUtils.deleteDirectory(new File(checkpointDir)) + } override def framework() = "CheckpointSuite" - override def checkpointFile() = "checkpoint" + override def batchDuration() = Seconds(1) + + override def checkpointDir() = "checkpoint" + + override def checkpointInterval() = batchDuration def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V], expectedOutput: Seq[Seq[V]], - useSet: Boolean = false + initialNumBatches: Int ) { // Current code assumes that: // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val initialNumBatches = input.size / 2 val nextNumBatches = totalNumBatches - initialNumBatches val initialNumExpectedOutputs = initialNumBatches + val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs // Do half the computation (half the number of batches), create checkpoint file and quit val ssc = setupStreams[U, V](input, operation) val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) - verifyOutput[V](output, expectedOutput.take(initialNumBatches), useSet) + verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) Thread.sleep(1000) // Restart and complete the computation from checkpoint file - val sscNew = new StreamingContext(checkpointFile) - sscNew.setCheckpointDetails(null, null) - val outputNew = runStreams[V](sscNew, nextNumBatches, expectedOutput.size) - verifyOutput[V](outputNew, expectedOutput, useSet) - - new File(checkpointFile).delete() - new File(checkpointFile + ".bk").delete() - new File("." + checkpointFile + ".crc").delete() - new File("." + checkpointFile + ".bk.crc").delete() + val sscNew = new StreamingContext(checkpointDir) + //sscNew.checkpoint(null, null) + val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs) + verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) } - test("simple per-batch operation") { + + test("map and reduceByKey") { testCheckpointedOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), - true + 3 ) } + + test("reduceByKeyAndWindowInv") { + val n = 10 + val w = 4 + val input = (1 to n).map(x => Seq("a")).toSeq + val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) + val operation = (st: DStream[String]) => { + st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, Seconds(w), Seconds(1)) + } + for (i <- Seq(3, 5, 7)) { + testCheckpointedOperation(input, operation, output, i) + } + } + + test("updateStateByKey") { + val input = (1 to 10).map(_ => Seq("a")).toSeq + val output = (1 to 10).map(x => Seq(("a", x))).toSeq + val operation = (st: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) + } + st.map(x => (x, 1)) + .updateStateByKey[RichInt](updateFunc) + .checkpoint(Seconds(5)) + .map(t => (t._1, t._2.self)) + } + for (i <- Seq(3, 5, 7)) { + testCheckpointedOperation(input, operation, output, i) + } + } + } \ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index c9bc454f91..e441feea19 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -5,10 +5,16 @@ import util.ManualClock import collection.mutable.ArrayBuffer import org.scalatest.FunSuite import collection.mutable.SynchronizedBuffer +import java.io.{ObjectInputStream, IOException} + +/** + * This is a input stream just for the testsuites. This is equivalent to a checkpointable, + * replayable, reliable message queue like Kafka. It requires a sequence as input, and + * returns the i_th element at the i_th batch unde manual clock. + */ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int) extends InputDStream[T](ssc_) { - var currentIndex = 0 def start() {} @@ -23,17 +29,32 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[ ssc.sc.makeRDD(Seq[T](), numPartitions) } logInfo("Created RDD " + rdd.id) - //currentIndex += 1 Some(rdd) } } +/** + * This is a output stream just for the testsuites. All the output is collected into a + * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. + */ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]]) extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() output += collected - }) + }) { + + // This is to clear the output buffer every it is read from a checkpoint + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + ois.defaultReadObject() + output.clear() + } +} +/** + * This is the base trait for Spark Streaming testsuites. This provides basic functionality + * to run user-defined set of input on user-defined stream operations, and verify the output. + */ trait TestSuiteBase extends FunSuite with Logging { System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") @@ -44,7 +65,7 @@ trait TestSuiteBase extends FunSuite with Logging { def batchDuration() = Seconds(1) - def checkpointFile() = null.asInstanceOf[String] + def checkpointDir() = null.asInstanceOf[String] def checkpointInterval() = batchDuration @@ -60,8 +81,8 @@ trait TestSuiteBase extends FunSuite with Logging { // Create StreamingContext val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - if (checkpointFile != null) { - ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) + if (checkpointDir != null) { + ssc.checkpoint(checkpointDir, checkpointInterval()) } // Setup the stream computation @@ -82,8 +103,8 @@ trait TestSuiteBase extends FunSuite with Logging { // Create StreamingContext val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - if (checkpointFile != null) { - ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) + if (checkpointDir != null) { + ssc.checkpoint(checkpointDir, checkpointInterval()) } // Setup the stream computation diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index d7d8d5bd36..e282f0fdd5 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -283,7 +283,9 @@ class WindowOperationsSuite extends TestSuiteBase { test("reduceByKeyAndWindowInv - " + name) { val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist() + s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime) + .persist() + .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing } testOperation(input, operation, expectedOutput, numBatches, true) } -- cgit v1.2.3 From 72b2303f99bd652fc4bdaa929f37731a7ba8f640 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 5 Nov 2012 11:41:36 -0800 Subject: Fixed major bugs in checkpointing. --- core/src/main/scala/spark/SparkContext.scala | 6 +- .../main/scala/spark/streaming/Checkpoint.scala | 24 ++-- .../src/main/scala/spark/streaming/DStream.scala | 47 +++++-- .../main/scala/spark/streaming/DStreamGraph.scala | 36 ++++-- .../src/main/scala/spark/streaming/Scheduler.scala | 1 - .../scala/spark/streaming/StreamingContext.scala | 8 +- .../scala/spark/streaming/CheckpointSuite.scala | 139 ++++++++++++++++----- .../test/scala/spark/streaming/TestSuiteBase.scala | 37 ++++-- .../spark/streaming/WindowOperationsSuite.scala | 6 +- 9 files changed, 217 insertions(+), 87 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d7326971a9..d7b46bee38 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -474,8 +474,10 @@ class SparkContext( /** Shut down the SparkContext. */ def stop() { - dagScheduler.stop() - dagScheduler = null + if (dagScheduler != null) { + dagScheduler.stop() + dagScheduler = null + } taskScheduler = null // TODO: Cache.stop()? env.stop() diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index cf04c7031e..6b4b05103f 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -6,6 +6,7 @@ import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream} +import sys.process.processInternal class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) @@ -52,17 +53,17 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) } } -object Checkpoint { +object Checkpoint extends Logging { def load(path: String): Checkpoint = { val fs = new Path(path).getFileSystem(new Configuration()) - val attempts = Seq(new Path(path), new Path(path, "graph"), new Path(path, "graph.bk")) - var lastException: Exception = null - var lastExceptionFile: String = null + val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk")) + var detailedLog: String = "" attempts.foreach(file => { if (fs.exists(file)) { + logInfo("Attempting to load checkpoint from file '" + file + "'") try { val fis = fs.open(file) // ObjectInputStream uses the last defined user-defined class loader in the stack @@ -75,21 +76,18 @@ object Checkpoint { ois.close() fs.close() cp.validate() - println("Checkpoint successfully loaded from file " + file) + logInfo("Checkpoint successfully loaded from file '" + file + "'") return cp } catch { case e: Exception => - lastException = e - lastExceptionFile = file.toString + logError("Error loading checkpoint from file '" + file + "'", e) } + } else { + logWarning("Could not load checkpoint from file '" + file + "' as it does not exist") } - }) - if (lastException == null) { - throw new Exception("Could not load checkpoint from path '" + path + "'") - } else { - throw new Exception("Error loading checkpoint from path '" + lastExceptionFile + "'", lastException) - } + }) + throw new Exception("Could not load checkpoint from path '" + path + "'") } def fromBytes(bytes: Array[Byte]): Checkpoint = { diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index de51c5d34a..2fecbe0acf 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -14,6 +14,8 @@ import java.util.concurrent.ArrayBlockingQueue import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import scala.Some import collection.mutable +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) extends Serializable with Logging { @@ -42,6 +44,7 @@ extends Serializable with Logging { */ // RDDs generated, marked as protected[streaming] so that testsuites can access it + @transient protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () // Time zero for the DStream @@ -112,7 +115,7 @@ extends Serializable with Logging { // Set the minimum value of the rememberDuration if not already set var minRememberDuration = slideTime if (checkpointInterval != null && minRememberDuration <= checkpointInterval) { - minRememberDuration = checkpointInterval + slideTime + minRememberDuration = checkpointInterval * 2 // times 2 just to be sure that the latest checkpoint is not forgetten } if (rememberDuration == null || rememberDuration < minRememberDuration) { rememberDuration = minRememberDuration @@ -265,33 +268,59 @@ extends Serializable with Logging { if (t <= (time - rememberDuration)) { generatedRDDs.remove(t) numForgotten += 1 - //logInfo("Forgot RDD of time " + t + " from " + this) + logInfo("Forgot RDD of time " + t + " from " + this) } }) logInfo("Forgot " + numForgotten + " RDDs from " + this) dependencies.foreach(_.forgetOldRDDs(time)) } + /** + * Refreshes the list of checkpointed RDDs that will be saved along with checkpoint of this stream. + * Along with that it forget old checkpoint files. + */ protected[streaming] def updateCheckpointData() { + + // TODO (tdas): This code can be simplified. Its kept verbose to aid debugging. + val checkpointedRDDs = generatedRDDs.filter(_._2.getCheckpointData() != null) + val removedCheckpointData = checkpointData.filter(x => !generatedRDDs.contains(x._1)) + checkpointData.clear() - generatedRDDs.foreach { - case(time, rdd) => { - logDebug("Adding checkpointed RDD for time " + time) + checkpointedRDDs.foreach { + case (time, rdd) => { val data = rdd.getCheckpointData() - if (data != null) { - checkpointData += ((time, data)) + assert(data != null) + checkpointData += ((time, data)) + logInfo("Added checkpointed RDD " + rdd + " for time " + time + " to stream checkpoint") + } + } + + dependencies.foreach(_.updateCheckpointData()) + // If at least one checkpoint is present, then delete old checkpoints + if (checkpointData.size > 0) { + // Delete the checkpoint RDD files that are not needed any more + removedCheckpointData.foreach { + case (time: Time, file: String) => { + val path = new Path(file) + val fs = path.getFileSystem(new Configuration()) + fs.delete(path, true) + logInfo("Deleted checkpoint file '" + file + "' for time " + time) } } } + + logInfo("Updated checkpoint data") } protected[streaming] def restoreCheckpointData() { + logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs") checkpointData.foreach { case(time, data) => { - logInfo("Restoring checkpointed RDD for time " + time) + logInfo("Restoring checkpointed RDD for time " + time + " from file") generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString))) } } + dependencies.foreach(_.restoreCheckpointData()) } @throws(classOf[IOException]) @@ -300,7 +329,6 @@ extends Serializable with Logging { if (graph != null) { graph.synchronized { if (graph.checkpointInProgress) { - updateCheckpointData() oos.defaultWriteObject() } else { val msg = "Object of " + this.getClass.getName + " is being serialized " + @@ -322,7 +350,6 @@ extends Serializable with Logging { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() generatedRDDs = new HashMap[Time, RDD[T]] () - restoreCheckpointData() } /** diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index f8922ec790..7437f4402d 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -4,7 +4,7 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import collection.mutable.ArrayBuffer import spark.Logging -final class DStreamGraph extends Serializable with Logging { +final private[streaming] class DStreamGraph extends Serializable with Logging { initLogging() private val inputStreams = new ArrayBuffer[InputDStream[_]]() @@ -15,7 +15,7 @@ final class DStreamGraph extends Serializable with Logging { private[streaming] var rememberDuration: Time = null private[streaming] var checkpointInProgress = false - def start(time: Time) { + private[streaming] def start(time: Time) { this.synchronized { if (zeroTime != null) { throw new Exception("DStream graph computation already started") @@ -28,7 +28,7 @@ final class DStreamGraph extends Serializable with Logging { } } - def stop() { + private[streaming] def stop() { this.synchronized { inputStreams.par.foreach(_.stop()) } @@ -40,7 +40,7 @@ final class DStreamGraph extends Serializable with Logging { } } - def setBatchDuration(duration: Time) { + private[streaming] def setBatchDuration(duration: Time) { this.synchronized { if (batchDuration != null) { throw new Exception("Batch duration already set as " + batchDuration + @@ -50,7 +50,7 @@ final class DStreamGraph extends Serializable with Logging { batchDuration = duration } - def setRememberDuration(duration: Time) { + private[streaming] def setRememberDuration(duration: Time) { this.synchronized { if (rememberDuration != null) { throw new Exception("Batch duration already set as " + batchDuration + @@ -60,37 +60,49 @@ final class DStreamGraph extends Serializable with Logging { rememberDuration = duration } - def addInputStream(inputStream: InputDStream[_]) { + private[streaming] def addInputStream(inputStream: InputDStream[_]) { this.synchronized { inputStream.setGraph(this) inputStreams += inputStream } } - def addOutputStream(outputStream: DStream[_]) { + private[streaming] def addOutputStream(outputStream: DStream[_]) { this.synchronized { outputStream.setGraph(this) outputStreams += outputStream } } - def getInputStreams() = inputStreams.toArray + private[streaming] def getInputStreams() = this.synchronized { inputStreams.toArray } - def getOutputStreams() = outputStreams.toArray + private[streaming] def getOutputStreams() = this.synchronized { outputStreams.toArray } - def generateRDDs(time: Time): Seq[Job] = { + private[streaming] def generateRDDs(time: Time): Seq[Job] = { this.synchronized { outputStreams.flatMap(outputStream => outputStream.generateJob(time)) } } - def forgetOldRDDs(time: Time) { + private[streaming] def forgetOldRDDs(time: Time) { this.synchronized { outputStreams.foreach(_.forgetOldRDDs(time)) } } - def validate() { + private[streaming] def updateCheckpointData() { + this.synchronized { + outputStreams.foreach(_.updateCheckpointData()) + } + } + + private[streaming] def restoreCheckpointData() { + this.synchronized { + outputStreams.foreach(_.restoreCheckpointData()) + } + } + + private[streaming] def validate() { this.synchronized { assert(batchDuration != null, "Batch duration has not been set") assert(batchDuration > Milliseconds(100), "Batch duration of " + batchDuration + " is very low") diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 7d52e2eddf..2b3f5a4829 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -58,7 +58,6 @@ extends Logging { graph.forgetOldRDDs(time) if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { ssc.doCheckpoint(time) - logInfo("Checkpointed at time " + time) } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 3838e84113..fb36ab9dc9 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -54,6 +54,7 @@ class StreamingContext ( val graph: DStreamGraph = { if (isCheckpointPresent) { cp_.graph.setContext(this) + cp_.graph.restoreCheckpointData() cp_.graph } else { new DStreamGraph() @@ -218,17 +219,16 @@ class StreamingContext ( if (scheduler != null) scheduler.stop() if (networkInputTracker != null) networkInputTracker.stop() if (receiverJobThread != null) receiverJobThread.interrupt() - sc.stop() + sc.stop() + logInfo("StreamingContext stopped successfully") } catch { case e: Exception => logWarning("Error while stopping", e) } - - logInfo("StreamingContext stopped") } def doCheckpoint(currentTime: Time) { + graph.updateCheckpointData() new Checkpoint(this, currentTime).save(checkpointDir) - } } diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index dfe31b5771..aa8ded513c 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -2,11 +2,11 @@ package spark.streaming import spark.streaming.StreamingContext._ import java.io.File -import collection.mutable.ArrayBuffer import runtime.RichInt import org.scalatest.BeforeAndAfter -import org.apache.hadoop.fs.Path import org.apache.commons.io.FileUtils +import collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import util.ManualClock class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { @@ -18,39 +18,83 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { FileUtils.deleteDirectory(new File(checkpointDir)) } - override def framework() = "CheckpointSuite" + override def framework = "CheckpointSuite" - override def batchDuration() = Seconds(1) + override def batchDuration = Milliseconds(500) - override def checkpointDir() = "checkpoint" + override def checkpointDir = "checkpoint" - override def checkpointInterval() = batchDuration + override def checkpointInterval = batchDuration - def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - initialNumBatches: Int - ) { + override def actuallyWait = true - // Current code assumes that: - // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val nextNumBatches = totalNumBatches - initialNumBatches - val initialNumExpectedOutputs = initialNumBatches - val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + test("basic stream+rdd recovery") { - // Do half the computation (half the number of batches), create checkpoint file and quit - val ssc = setupStreams[U, V](input, operation) - val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) - verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) - Thread.sleep(1000) + assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - // Restart and complete the computation from checkpoint file + val checkpointingInterval = Seconds(2) + + // this ensure checkpointing occurs at least once + val firstNumBatches = (checkpointingInterval.millis / batchDuration.millis) * 2 + val secondNumBatches = firstNumBatches + + // Setup the streams + val input = (1 to 10).map(_ => Seq("a")).toSeq + val operation = (st: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) + } + st.map(x => (x, 1)) + .updateStateByKey[RichInt](updateFunc) + .checkpoint(checkpointingInterval) + .map(t => (t._1, t._2.self)) + } + val ssc = setupStreams(input, operation) + val stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head + + // Run till a time such that at least one RDD in the stream should have been checkpointed + ssc.start() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + logInfo("Manual clock before advancing = " + clock.time) + for (i <- 1 to firstNumBatches.toInt) { + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(batchDuration.milliseconds) + } + logInfo("Manual clock after advancing = " + clock.time) + Thread.sleep(batchDuration.milliseconds) + + // Check whether some RDD has been checkpointed or not + logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.mkString(",\n") + "]") + assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream") + stateStream.checkpointData.foreach { + case (time, data) => { + val file = new File(data.toString) + assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " does not exist") + } + } + val checkpointFiles = stateStream.checkpointData.map(x => new File(x._2.toString)) + + // Run till a further time such that previous checkpoint files in the stream would be deleted + logInfo("Manual clock before advancing = " + clock.time) + for (i <- 1 to secondNumBatches.toInt) { + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(batchDuration.milliseconds) + } + logInfo("Manual clock after advancing = " + clock.time) + Thread.sleep(batchDuration.milliseconds) + + // Check whether the earlier checkpoint files are deleted + checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) + + // Restart stream computation using the checkpoint file and check whether + // checkpointed RDDs have been restored or not + ssc.stop() val sscNew = new StreamingContext(checkpointDir) - //sscNew.checkpoint(null, null) - val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs) - verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) + val stateStreamNew = sscNew.graph.getOutputStreams().head.dependencies.head.dependencies.head + logInfo("Restored data of state stream = \n[" + stateStreamNew.generatedRDDs.mkString("\n") + "]") + assert(!stateStreamNew.generatedRDDs.isEmpty, "No restored RDDs in state stream") + sscNew.stop() } @@ -69,9 +113,9 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val input = (1 to n).map(x => Seq("a")).toSeq val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) val operation = (st: DStream[String]) => { - st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, Seconds(w), Seconds(1)) + st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, batchDuration * 4, batchDuration) } - for (i <- Seq(3, 5, 7)) { + for (i <- Seq(2, 3, 4)) { testCheckpointedOperation(input, operation, output, i) } } @@ -85,12 +129,45 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { } st.map(x => (x, 1)) .updateStateByKey[RichInt](updateFunc) - .checkpoint(Seconds(5)) + .checkpoint(Seconds(2)) .map(t => (t._1, t._2.self)) } - for (i <- Seq(3, 5, 7)) { + for (i <- Seq(2, 3, 4)) { testCheckpointedOperation(input, operation, output, i) } } + + + def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + initialNumBatches: Int + ) { + + // Current code assumes that: + // number of inputs = number of outputs = number of batches to be run + val totalNumBatches = input.size + val nextNumBatches = totalNumBatches - initialNumBatches + val initialNumExpectedOutputs = initialNumBatches + val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + + // Do half the computation (half the number of batches), create checkpoint file and quit + + val ssc = setupStreams[U, V](input, operation) + val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) + verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) + Thread.sleep(1000) + + // Restart and complete the computation from checkpoint file + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation " + + "\n-------------------------------------------\n" + ) + val sscNew = new StreamingContext(checkpointDir) + val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs) + verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) + } } \ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index e441feea19..b8c7f99603 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -57,21 +57,21 @@ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBu */ trait TestSuiteBase extends FunSuite with Logging { - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + def framework = "TestSuiteBase" - def framework() = "TestSuiteBase" + def master = "local[2]" - def master() = "local[2]" + def batchDuration = Seconds(1) - def batchDuration() = Seconds(1) + def checkpointDir = null.asInstanceOf[String] - def checkpointDir() = null.asInstanceOf[String] + def checkpointInterval = batchDuration - def checkpointInterval() = batchDuration + def numInputPartitions = 2 - def numInputPartitions() = 2 + def maxWaitTimeMillis = 10000 - def maxWaitTimeMillis() = 10000 + def actuallyWait = false def setupStreams[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], @@ -82,7 +82,7 @@ trait TestSuiteBase extends FunSuite with Logging { val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) if (checkpointDir != null) { - ssc.checkpoint(checkpointDir, checkpointInterval()) + ssc.checkpoint(checkpointDir, checkpointInterval) } // Setup the stream computation @@ -104,7 +104,7 @@ trait TestSuiteBase extends FunSuite with Logging { val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) if (checkpointDir != null) { - ssc.checkpoint(checkpointDir, checkpointInterval()) + ssc.checkpoint(checkpointDir, checkpointInterval) } // Setup the stream computation @@ -118,12 +118,19 @@ trait TestSuiteBase extends FunSuite with Logging { ssc } + /** + * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and + * returns the collected output. It will wait until `numExpectedOutput` number of + * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached. + */ def runStreams[V: ClassManifest]( ssc: StreamingContext, numBatches: Int, numExpectedOutput: Int ): Seq[Seq[V]] = { + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + assert(numBatches > 0, "Number of batches to run stream computation is zero") assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero") logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) @@ -139,7 +146,15 @@ trait TestSuiteBase extends FunSuite with Logging { // Advance manual clock val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.time) - clock.addToTime(numBatches * batchDuration.milliseconds) + if (actuallyWait) { + for (i <- 1 to numBatches) { + logInfo("Actually waiting for " + batchDuration) + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(batchDuration.milliseconds) + } + } else { + clock.addToTime(numBatches * batchDuration.milliseconds) + } logInfo("Manual clock after advancing = " + clock.time) // Wait until expected number of output items have been generated diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index e282f0fdd5..3e20e16708 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -5,11 +5,11 @@ import collection.mutable.ArrayBuffer class WindowOperationsSuite extends TestSuiteBase { - override def framework() = "WindowOperationsSuite" + override def framework = "WindowOperationsSuite" - override def maxWaitTimeMillis() = 20000 + override def maxWaitTimeMillis = 20000 - override def batchDuration() = Seconds(1) + override def batchDuration = Seconds(1) val largerSlideInput = Seq( Seq(("a", 1)), -- cgit v1.2.3 From 355c8e4b17cc3e67b1e18cc24e74d88416b5779b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 9 Nov 2012 16:28:45 -0800 Subject: Fixed deadlock in BlockManager. --- .../main/scala/spark/storage/BlockManager.scala | 111 ++++++++++----------- .../src/main/scala/spark/storage/MemoryStore.scala | 79 +++++++++------ 2 files changed, 101 insertions(+), 89 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index bd9155ef29..8c7b1417be 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -50,16 +50,6 @@ private[spark] case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message) - -private[spark] class BlockLocker(numLockers: Int) { - private val hashLocker = Array.fill(numLockers)(new Object()) - - def getLock(blockId: String): Object = { - return hashLocker(math.abs(blockId.hashCode % numLockers)) - } -} - - private[spark] class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, maxMemory: Long) extends Logging { @@ -87,10 +77,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - private val NUM_LOCKS = 337 - private val locker = new BlockLocker(NUM_LOCKS) - - private val blockInfo = new ConcurrentHashMap[String, BlockInfo]() + private val blockInfo = new ConcurrentHashMap[String, BlockInfo](1000) private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) private[storage] val diskStore: BlockStore = @@ -110,7 +97,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val maxBytesInFlight = System.getProperty("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024 + // Whether to compress broadcast variables that are stored val compressBroadcast = System.getProperty("spark.broadcast.compress", "true").toBoolean + // Whether to compress shuffle output that are stored val compressShuffle = System.getProperty("spark.shuffle.compress", "true").toBoolean // Whether to compress RDD partitions that are stored serialized val compressRdds = System.getProperty("spark.rdd.compress", "false").toBoolean @@ -150,28 +139,27 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk. */ def reportBlockStatus(blockId: String) { - locker.getLock(blockId).synchronized { - val curLevel = blockInfo.get(blockId) match { - case null => - StorageLevel.NONE - case info => + val (curLevel, inMemSize, onDiskSize) = blockInfo.get(blockId) match { + case null => + (StorageLevel.NONE, 0L, 0L) + case info => + info.synchronized { info.level match { case null => - StorageLevel.NONE + (StorageLevel.NONE, 0L, 0L) case level => val inMem = level.useMemory && memoryStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) - new StorageLevel(onDisk, inMem, level.deserialized, level.replication) + ( + new StorageLevel(onDisk, inMem, level.deserialized, level.replication), + if (inMem) memoryStore.getSize(blockId) else 0L, + if (onDisk) diskStore.getSize(blockId) else 0L + ) } - } - master.mustHeartBeat(HeartBeat( - blockManagerId, - blockId, - curLevel, - if (curLevel.useMemory) memoryStore.getSize(blockId) else 0L, - if (curLevel.useDisk) diskStore.getSize(blockId) else 0L)) - logDebug("Told master about block " + blockId) + } } + master.mustHeartBeat(HeartBeat(blockManagerId, blockId, curLevel, inMemSize, onDiskSize)) + logDebug("Told master about block " + blockId) } /** @@ -213,9 +201,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - locker.getLock(blockId).synchronized { - val info = blockInfo.get(blockId) - if (info != null) { + val info = blockInfo.get(blockId) + if (info != null) { + info.synchronized { info.waitForReady() // In case the block is still being put() by another thread val level = info.level logDebug("Level for block " + blockId + " is " + level) @@ -273,9 +261,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } } - } else { - logDebug("Block " + blockId + " not registered locally") } + } else { + logDebug("Block " + blockId + " not registered locally") } return None } @@ -298,9 +286,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - locker.getLock(blockId).synchronized { - val info = blockInfo.get(blockId) - if (info != null) { + val info = blockInfo.get(blockId) + if (info != null) { + info.synchronized { info.waitForReady() // In case the block is still being put() by another thread val level = info.level logDebug("Level for block " + blockId + " is " + level) @@ -338,10 +326,11 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new Exception("Block " + blockId + " not found on disk, though it should be") } } - } else { - logDebug("Block " + blockId + " not registered locally") } + } else { + logDebug("Block " + blockId + " not registered locally") } + return None } @@ -583,7 +572,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // Size of the block in bytes (to return to caller) var size = 0L - locker.getLock(blockId).synchronized { + myInfo.synchronized { logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") @@ -681,7 +670,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m null } - locker.getLock(blockId).synchronized { + myInfo.synchronized { logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") @@ -779,26 +768,30 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { logInfo("Dropping block " + blockId + " from memory") - locker.getLock(blockId).synchronized { - val info = blockInfo.get(blockId) - val level = info.level - if (level.useDisk && !diskStore.contains(blockId)) { - logInfo("Writing block " + blockId + " to disk") - data match { - case Left(elements) => - diskStore.putValues(blockId, elements, level, false) - case Right(bytes) => - diskStore.putBytes(blockId, bytes, level) + val info = blockInfo.get(blockId) + if (info != null) { + info.synchronized { + val level = info.level + if (level.useDisk && !diskStore.contains(blockId)) { + logInfo("Writing block " + blockId + " to disk") + data match { + case Left(elements) => + diskStore.putValues(blockId, elements, level, false) + case Right(bytes) => + diskStore.putBytes(blockId, bytes, level) + } + } + memoryStore.remove(blockId) + if (info.tellMaster) { + reportBlockStatus(blockId) + } + if (!level.useDisk) { + // The block is completely gone from this node; forget it so we can put() it again later. + blockInfo.remove(blockId) } } - memoryStore.remove(blockId) - if (info.tellMaster) { - reportBlockStatus(blockId) - } - if (!level.useDisk) { - // The block is completely gone from this node; forget it so we can put() it again later. - blockInfo.remove(blockId) - } + } else { + // The block has already been dropped } } diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala index 773970446a..09769d1f7d 100644 --- a/core/src/main/scala/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/spark/storage/MemoryStore.scala @@ -17,13 +17,16 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true) private var currentMemory = 0L + // Object used to ensure that only one thread is putting blocks and if necessary, dropping + // blocks from the memory store. + private val putLock = new Object() logInfo("MemoryStore started with capacity %s.".format(Utils.memoryBytesToString(maxMemory))) def freeMemory: Long = maxMemory - currentMemory override def getSize(blockId: String): Long = { - synchronized { + entries.synchronized { entries.get(blockId).size } } @@ -38,8 +41,6 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) tryToPut(blockId, elements, sizeEstimate, true) } else { val entry = new Entry(bytes, bytes.limit, false) - ensureFreeSpace(blockId, bytes.limit) - synchronized { entries.put(blockId, entry) } tryToPut(blockId, bytes, bytes.limit, false) } } @@ -63,7 +64,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def getBytes(blockId: String): Option[ByteBuffer] = { - val entry = synchronized { + val entry = entries.synchronized { entries.get(blockId) } if (entry == null) { @@ -76,7 +77,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def getValues(blockId: String): Option[Iterator[Any]] = { - val entry = synchronized { + val entry = entries.synchronized { entries.get(blockId) } if (entry == null) { @@ -90,7 +91,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def remove(blockId: String) { - synchronized { + entries.synchronized { val entry = entries.get(blockId) if (entry != null) { entries.remove(blockId) @@ -104,7 +105,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def clear() { - synchronized { + entries.synchronized { entries.clear() } logInfo("MemoryStore cleared") @@ -125,12 +126,22 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Try to put in a set of values, if we can free up enough space. The value should either be * an ArrayBuffer if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) * size must also be passed by the caller. + * + * Locks on the object putLock to ensure that all the put requests and its associated block + * dropping is done by only on thread at a time. Otherwise while one thread is dropping + * blocks to free memory for one block, another thread may use up the freed space for + * another block. */ private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = { - synchronized { + // TODO: Its possible to optimize the locking by locking entries only when selecting blocks + // to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has been + // released, it must be ensured that those to-be-dropped blocks are not double counted for + // freeing up more space for another block that needs to be put. Only then the actually dropping + // of blocks (and writing to disk if necessary) can proceed in parallel. + putLock.synchronized { if (ensureFreeSpace(blockId, size)) { val entry = new Entry(value, size, deserialized) - entries.put(blockId, entry) + entries.synchronized { entries.put(blockId, entry) } currentMemory += size if (deserialized) { logInfo("Block %s stored as values to memory (estimated size %s, free %s)".format( @@ -160,8 +171,8 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that * don't fit into memory that we want to avoid). * - * Assumes that a lock on the MemoryStore is held by the caller. (Otherwise, the freed space - * might fill up before the caller puts in their new value.) + * Assumes that a lock is held by the caller to ensure only one thread is dropping blocks. + * Otherwise, the freed space may fill up before the caller puts in their new value. */ private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = { logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( @@ -172,36 +183,44 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) return false } - // TODO: This should relinquish the lock on the MemoryStore while flushing out old blocks - // in order to allow parallelism in writing to disk if (maxMemory - currentMemory < space) { val rddToAdd = getRddId(blockIdToAdd) val selectedBlocks = new ArrayBuffer[String]() var selectedMemory = 0L - val iterator = entries.entrySet().iterator() - while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { - val pair = iterator.next() - val blockId = pair.getKey - if (rddToAdd != null && rddToAdd == getRddId(blockId)) { - logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + - "block from the same RDD") - return false + // This is synchronized to ensure that the set of entries is not changed + // (because of getValue or getBytes) while traversing the iterator, as that + // can lead to exceptions. + entries.synchronized { + val iterator = entries.entrySet().iterator() + while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { + val pair = iterator.next() + val blockId = pair.getKey + if (rddToAdd != null && rddToAdd == getRddId(blockId)) { + logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + + "block from the same RDD") + return false + } + selectedBlocks += blockId + selectedMemory += pair.getValue.size } - selectedBlocks += blockId - selectedMemory += pair.getValue.size } if (maxMemory - (currentMemory - selectedMemory) >= space) { logInfo(selectedBlocks.size + " blocks selected for dropping") for (blockId <- selectedBlocks) { - val entry = entries.get(blockId) - val data = if (entry.deserialized) { - Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) - } else { - Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + val entry = entries.synchronized { entries.get(blockId) } + // This should never be null as only one thread should be dropping + // blocks and removing entries. However the check is still here for + // future safety. + if (entry != null) { + val data = if (entry.deserialized) { + Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) + } else { + Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + } + blockManager.dropFromMemory(blockId, data) } - blockManager.dropFromMemory(blockId, data) } return true } else { @@ -212,7 +231,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def contains(blockId: String): Boolean = { - synchronized { entries.containsKey(blockId) } + entries.synchronized { entries.containsKey(blockId) } } } -- cgit v1.2.3 From 04e9e9d93c512f856116bc2c99c35dfb48b4adee Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 11 Nov 2012 08:54:21 -0800 Subject: Refactored BlockManagerMaster (not BlockManagerMasterActor) to simplify the code and fix live lock problem in unlimited attempts to contact the master. Also added testcases in the BlockManagerSuite to test BlockManagerMaster methods getPeers and getLocations. --- .../main/scala/spark/storage/BlockManager.scala | 14 +- .../scala/spark/storage/BlockManagerMaster.scala | 281 +++++++-------------- .../scala/spark/storage/BlockManagerSuite.scala | 30 ++- 3 files changed, 127 insertions(+), 198 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 8c7b1417be..70d6d8369d 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -120,8 +120,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * BlockManagerWorker actor. */ private def initialize() { - master.mustRegisterBlockManager( - RegisterBlockManager(blockManagerId, maxMemory)) + master.registerBlockManager(blockManagerId, maxMemory) BlockManagerWorker.startBlockManagerWorker(this) } @@ -158,7 +157,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } } - master.mustHeartBeat(HeartBeat(blockManagerId, blockId, curLevel, inMemSize, onDiskSize)) + master.updateBlockInfo(blockManagerId, blockId, curLevel, inMemSize, onDiskSize) logDebug("Told master about block " + blockId) } @@ -167,7 +166,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def getLocations(blockId: String): Seq[String] = { val startTimeMs = System.currentTimeMillis - var managers = master.mustGetLocations(GetLocations(blockId)) + var managers = master.getLocations(blockId) val locations = managers.map(_.ip) logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs)) return locations @@ -178,8 +177,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def getLocations(blockIds: Array[String]): Array[Seq[String]] = { val startTimeMs = System.currentTimeMillis - val locations = master.mustGetLocationsMultipleBlockIds( - GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray + val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -343,7 +341,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } logDebug("Getting remote block " + blockId) // Get locations of block - val locations = master.mustGetLocations(GetLocations(blockId)) + val locations = master.getLocations(blockId) // Get block from remote locations for (loc <- locations) { @@ -721,7 +719,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val tLevel: StorageLevel = new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) if (cachedPeers == null) { - cachedPeers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + cachedPeers = master.getPeers(blockManagerId, level.replication - 1) } for (peer: BlockManagerId <- cachedPeers) { val start = System.nanoTime diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index b3345623b3..4d5ee8318c 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -26,7 +26,7 @@ case class RegisterBlockManager( extends ToBlockManagerMaster private[spark] -class HeartBeat( +class UpdateBlockInfo( var blockManagerId: BlockManagerId, var blockId: String, var storageLevel: StorageLevel, @@ -57,17 +57,17 @@ class HeartBeat( } private[spark] -object HeartBeat { +object UpdateBlockInfo { def apply(blockManagerId: BlockManagerId, blockId: String, storageLevel: StorageLevel, memSize: Long, - diskSize: Long): HeartBeat = { - new HeartBeat(blockManagerId, blockId, storageLevel, memSize, diskSize) + diskSize: Long): UpdateBlockInfo = { + new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize) } // For pattern-matching - def unapply(h: HeartBeat): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { + def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) } } @@ -182,8 +182,8 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor case RegisterBlockManager(blockManagerId, maxMemSize) => register(blockManagerId, maxMemSize) - case HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) => - heartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) + case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => + updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) case GetLocations(blockId) => getLocations(blockId) @@ -233,7 +233,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor sender ! true } - private def heartBeat( + private def updateBlockInfo( blockManagerId: BlockManagerId, blockId: String, storageLevel: StorageLevel, @@ -245,7 +245,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor if (blockId == null) { blockManagerInfo(blockManagerId).updateLastSeenMs() - logDebug("Got in heartBeat 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) + logDebug("Got in updateBlockInfo 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) sender ! true } @@ -350,211 +350,124 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: Boolean) extends Logging { - val AKKA_ACTOR_NAME: String = "BlockMasterManager" - val REQUEST_RETRY_INTERVAL_MS = 100 - val DEFAULT_MASTER_IP: String = System.getProperty("spark.master.host", "localhost") - val DEFAULT_MASTER_PORT: Int = System.getProperty("spark.master.port", "7077").toInt - val DEFAULT_MANAGER_IP: String = Utils.localHostName() - val DEFAULT_MANAGER_PORT: String = "10902" - + val actorName = "BlockMasterManager" val timeout = 10.seconds - var masterActor: ActorRef = null + val maxAttempts = 5 - if (isMaster) { - masterActor = actorSystem.actorOf( - Props(new BlockManagerMasterActor(isLocal)), name = AKKA_ACTOR_NAME) + var masterActor = if (isMaster) { + val actor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)), name = actorName) logInfo("Registered BlockManagerMaster Actor") + actor } else { - val url = "akka://spark@%s:%s/user/%s".format( - DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT, AKKA_ACTOR_NAME) + val host = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val url = "akka://spark@%s:%s/user/%s".format(host, port, actorName) + val actor = actorSystem.actorFor(url) logInfo("Connecting to BlockManagerMaster: " + url) - masterActor = actorSystem.actorFor(url) + actor } - def stop() { - if (masterActor != null) { - communicate(StopBlockManagerMaster) - masterActor = null - logInfo("BlockManagerMaster stopped") + /** + * Send a message to the master actor and get its result within a default timeout, or + * throw a SparkException if this fails. + */ + private def ask[T](message: Any): T = { + // TODO: Consider removing multiple attempts + if (masterActor == null) { + throw new SparkException("Error sending message to BlockManager as masterActor is null " + + "[message = " + message + "]") } - } - - // Send a message to the master actor and get its result within a default timeout, or - // throw a SparkException if this fails. - def askMaster(message: Any): Any = { - try { - val future = masterActor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with BlockManagerMaster", e) + var attempts = 0 + var lastException: Exception = null + while (attempts < maxAttempts) { + attempts += 1 + try { + val future = masterActor.ask(message)(timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new Exception("BlockManagerMaster returned null") + } + return result.asInstanceOf[T] + } catch { + case ie: InterruptedException => + throw ie + case e: Exception => + lastException = e + logWarning( + "Error sending message to BlockManagerMaster in " + attempts + " attempts", e) + } + Thread.sleep(100) } + throw new SparkException( + "Error sending message to BlockManagerMaster [message = " + message + "]", lastException) } - // Send a one-way message to the master actor, to which we expect it to reply with true. - def communicate(message: Any) { - if (askMaster(message) != true) { - throw new SparkException("Error reply received from BlockManagerMaster") + /** + * Send a one-way message to the master actor, to which we expect it to reply with true + */ + private def tell(message: Any) { + if (!ask[Boolean](message)) { + throw new SparkException("Telling master a message returned false") } } - def notifyADeadHost(host: String) { - communicate(RemoveHost(host + ":" + DEFAULT_MANAGER_PORT)) - logInfo("Removed " + host + " successfully in notifyADeadHost") - } - - def mustRegisterBlockManager(msg: RegisterBlockManager) { + /** + * Register the BlockManager's id with the master + */ + def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long) { logInfo("Trying to register BlockManager") - while (! syncRegisterBlockManager(msg)) { - logWarning("Failed to register " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - } - logInfo("Done registering BlockManager") - } - - def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = { - //val masterActor = RemoteActor.select(node, name) - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - communicate(msg) - logInfo("BlockManager registered successfully @ syncRegisterBlockManager") - logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return true - } catch { - case e: Exception => - logError("Failed in syncRegisterBlockManager", e) - return false - } + tell(RegisterBlockManager(blockManagerId, maxMemSize)) + logInfo("Registered BlockManager") } - def mustHeartBeat(msg: HeartBeat) { - while (! syncHeartBeat(msg)) { - logWarning("Failed to send heartbeat" + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - } - } - - def syncHeartBeat(msg: HeartBeat): Boolean = { - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncHeartBeat " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs)) - - try { - communicate(msg) - logDebug("Heartbeat sent successfully") - logDebug("Got in syncHeartBeat 1 " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs)) - return true - } catch { - case e: Exception => - logError("Failed in syncHeartBeat", e) - return false - } + def updateBlockInfo( + blockManagerId: BlockManagerId, + blockId: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long + ) { + tell(UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) + logInfo("Updated info of block " + blockId) } - def mustGetLocations(msg: GetLocations): Seq[BlockManagerId] = { - var res = syncGetLocations(msg) - while (res == null) { - logInfo("Failed to get locations " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetLocations(msg) - } - return res + /** Get locations of the blockId from the master */ + def getLocations(blockId: String): Seq[BlockManagerId] = { + ask[Seq[BlockManagerId]](GetLocations(blockId)) } - def syncGetLocations(msg: GetLocations): Seq[BlockManagerId] = { - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[ArrayBuffer[BlockManagerId]] - if (answer != null) { - logDebug("GetLocations successful") - logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetLocations") - return null - } - } catch { - case e: Exception => - logError("GetLocations failed", e) - return null - } + /** Get locations of multiple blockIds from the master */ + def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + ask[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } - def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): - Seq[Seq[BlockManagerId]] = { - var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg) - while (res == null) { - logWarning("Failed to GetLocationsMultipleBlockIds " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetLocationsMultipleBlockIds(msg) + /** Get ids of other nodes in the cluster from the master */ + def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { + val result = ask[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) + if (result.length != numPeers) { + throw new SparkException( + "Error getting peers, only got " + result.size + " instead of " + numPeers) } - return res + result } - def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): - Seq[Seq[BlockManagerId]] = { - val startTimeMs = System.currentTimeMillis - val tmp = " msg " + msg + " " - logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[Seq[Seq[BlockManagerId]]] - if (answer != null) { - logDebug("GetLocationsMultipleBlockIds successful") - logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp + - Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetLocationsMultipleBlockIds") - return null - } - } catch { - case e: Exception => - logError("GetLocationsMultipleBlockIds failed", e) - return null - } + /** Notify the master of a dead node */ + def notifyADeadHost(host: String) { + tell(RemoveHost(host + ":10902")) + logInfo("Told BlockManagerMaster to remove dead host " + host) } - def mustGetPeers(msg: GetPeers): Seq[BlockManagerId] = { - var res = syncGetPeers(msg) - while ((res == null) || (res.length != msg.size)) { - logInfo("Failed to get peers " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetPeers(msg) - } - - return res + /** Get the memory status form the master */ + def getMemoryStatus(): Map[BlockManagerId, (Long, Long)] = { + ask[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } - def syncGetPeers(msg: GetPeers): Seq[BlockManagerId] = { - val startTimeMs = System.currentTimeMillis - val tmp = " msg " + msg + " " - logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[Seq[BlockManagerId]] - if (answer != null) { - logDebug("GetPeers successful") - logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetPeers") - return null - } - } catch { - case e: Exception => - logError("GetPeers failed", e) - return null + /** Stop the master actor, called only on the Spark master node */ + def stop() { + if (masterActor != null) { + tell(StopBlockManagerMaster) + masterActor = null + logInfo("BlockManagerMaster stopped") } } - - def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askMaster(GetMemoryStatus).asInstanceOf[Map[BlockManagerId, (Long, Long)]] - } } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index b9c19e61cd..0e78228134 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -20,9 +20,11 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT var oldOops: String = null // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test + System.setProperty("spark.kryoserializer.buffer.mb", "1") val serializer = new KryoSerializer before { + actorSystem = ActorSystem("test") master = new BlockManagerMaster(actorSystem, true, true) @@ -55,7 +57,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } } - test("manager-master interaction") { + test("master + 1 manager interaction") { store = new BlockManager(master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -72,17 +74,33 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.getSingle("a3") != None, "a3 was not in store") // Checking whether master knows about the blocks or not - assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1") - assert(master.mustGetLocations(GetLocations("a2")).size > 0, "master was not told about a2") - assert(master.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3") + assert(master.getLocations("a1").size === 1, "master was not told about a1") + assert(master.getLocations("a2").size === 1, "master was not told about a2") + assert(master.getLocations("a3").size === 0, "master was told about a3") // Drop a1 and a2 from memory; this should be reported back to the master store.dropFromMemory("a1", null) store.dropFromMemory("a2", null) assert(store.getSingle("a1") === None, "a1 not removed from store") assert(store.getSingle("a2") === None, "a2 not removed from store") - assert(master.mustGetLocations(GetLocations("a1")).size === 0, "master did not remove a1") - assert(master.mustGetLocations(GetLocations("a2")).size === 0, "master did not remove a2") + assert(master.getLocations("a1").size === 0, "master did not remove a1") + assert(master.getLocations("a2").size === 0, "master did not remove a2") + } + + test("master + 2 managers interaction") { + store = new BlockManager(master, serializer, 2000) + val otherStore = new BlockManager(master, new KryoSerializer, 2000) + + val peers = master.getPeers(store.blockManagerId, 1) + assert(peers.size === 1, "master did not return the other manager as a peer") + assert(peers.head === otherStore.blockManagerId, "peer returned by master is not the other manager") + + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_2) + otherStore.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2) + assert(master.getLocations("a1").size === 2, "master did not report 2 locations for a1") + assert(master.getLocations("a2").size === 2, "master did not report 2 locations for a2") } test("in-memory LRU storage") { -- cgit v1.2.3 From 4a1be7e0dbf0031d85b91dc1132fe101d87ba097 Mon Sep 17 00:00:00 2001 From: Denny Date: Mon, 12 Nov 2012 10:56:35 -0800 Subject: Refactor BlockManager UI and adding worker details. --- core/src/main/scala/spark/RDD.scala | 7 +- core/src/main/scala/spark/SparkContext.scala | 2 - .../scala/spark/storage/BlockManagerMaster.scala | 11 ++- .../main/scala/spark/storage/BlockManagerUI.scala | 51 ++++---------- .../main/scala/spark/storage/StorageLevel.scala | 9 +++ .../main/scala/spark/storage/StorageUtils.scala | 78 ++++++++++++++++++++++ core/src/main/twirl/spark/storage/index.scala.html | 22 ++++-- core/src/main/twirl/spark/storage/rdd.scala.html | 35 ++++++---- .../main/twirl/spark/storage/rdd_row.scala.html | 18 ----- .../main/twirl/spark/storage/rdd_table.scala.html | 16 ++++- .../twirl/spark/storage/worker_table.scala.html | 24 +++++++ 11 files changed, 186 insertions(+), 87 deletions(-) create mode 100644 core/src/main/scala/spark/storage/StorageUtils.scala delete mode 100644 core/src/main/twirl/spark/storage/rdd_row.scala.html create mode 100644 core/src/main/twirl/spark/storage/worker_table.scala.html (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index dc757dc6aa..3669bda2d2 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -86,6 +86,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial @transient val dependencies: List[Dependency[_]] // Methods available on all RDDs: + + // A friendly name for this RDD + var name: String = null /** Record user function generating this RDD. */ private[spark] val origin = Utils.getSparkCallSite @@ -108,8 +111,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial private var storageLevel: StorageLevel = StorageLevel.NONE /* Assign a name to this RDD */ - def name(name: String) = { - sc.rddNames(this.id) = name + def setName(_name: String) = { + name = _name this } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 71c9dcd017..7ea0f6f9e0 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -113,8 +113,6 @@ class SparkContext( // Keeps track of all persisted RDDs private[spark] val persistentRdds = new ConcurrentHashMap[Int, RDD[_]]() - // A HashMap for friendly RDD Names - private[spark] val rddNames = new ConcurrentHashMap[Int, String]() // Add each JAR given through the constructor jars.foreach { addJar(_) } diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 3fc9b629c1..beafdda9d1 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -4,7 +4,7 @@ import java.io._ import java.util.{HashMap => JHashMap} import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.Random import akka.actor._ @@ -95,10 +95,7 @@ private[spark] case class GetStorageStatus extends ToBlockManagerMaster private[spark] -case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) - -private[spark] -case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, remainingMem: Long, blocks: Map[String, BlockStatus]) +case class BlockStatus(blockManagerId: BlockManagerId, storageLevel: StorageLevel, memSize: Long, diskSize: Long) private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { @@ -135,7 +132,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor if (storageLevel.isValid) { // isValid means it is either stored in-memory or on-disk. - _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) + _blocks.put(blockId, BlockStatus(blockManagerId, storageLevel, memSize, diskSize)) if (storageLevel.useMemory) { _remainingMem -= memSize logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( @@ -237,7 +234,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor private def getStorageStatus() { val res = blockManagerInfo.map { case(blockManagerId, info) => - StorageStatus(blockManagerId, info.maxMem, info.remainingMem, info.blocks.asScala) + StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap) } sender ! res } diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 635c096c87..35cbd59280 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -12,6 +12,7 @@ import scala.collection.mutable.ArrayBuffer import spark.{Logging, SparkContext, SparkEnv} import spark.util.AkkaUtils + private[spark] object BlockManagerUI extends Logging { @@ -32,9 +33,6 @@ object BlockManagerUI extends Logging { } -private[spark] -case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, - numPartitions: Int, memSize: Long, diskSize: Long) private[spark] class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, @@ -49,21 +47,17 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, // Request the current storage status from the Master val future = master ? GetStorageStatus future.map { status => - val storageStati = status.asInstanceOf[ArrayBuffer[StorageStatus]] + val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray // Calculate macro-level statistics - val maxMem = storageStati.map(_.maxMem).reduce(_+_) - val remainingMem = storageStati.map(_.remainingMem).reduce(_+_) - val diskSpaceUsed = storageStati.flatMap(_.blocks.values.map(_.diskSize)) + val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) + val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) + val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)) .reduceOption(_+_).getOrElse(0L) - // Filter out everything that's not and rdd. - val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => - k.startsWith("rdd") - }.toMap - val rdds = rddInfoFromBlockStati(rddBlocks) + val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) - spark.storage.html.index.render(maxMem, remainingMem, diskSpaceUsed, rdds.toList) + spark.storage.html.index.render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) } }}} ~ get { path("rdd") { parameter("id") { id => { completeWith { @@ -71,13 +65,13 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, future.map { status => val prefix = "rdd_" + id.toString - val storageStati = status.asInstanceOf[ArrayBuffer[StorageStatus]] - val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => - k.startsWith(prefix) - }.toMap - val rddInfo = rddInfoFromBlockStati(rddBlocks).first - spark.storage.html.rdd.render(rddInfo, rddBlocks) + val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray + val filteredStorageStatusList = StorageUtils.filterStorageStatusByPrefix(storageStatusList, prefix) + + val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).first + + spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) } }}}}} ~ @@ -87,23 +81,6 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, } - private def rddInfoFromBlockStati(infos: Map[String, BlockStatus]) : Array[RDDInfo] = { - infos.groupBy { case(k,v) => - // Group by rdd name, ignore the partition name - k.substring(0,k.lastIndexOf('_')) - }.map { case(k,v) => - val blockStati = v.map(_._2).toArray - // Add up memory and disk sizes - val tmp = blockStati.map { x => (x.memSize, x.diskSize)}.reduce { (x,y) => - (x._1 + y._1, x._2 + y._2) - } - // Get the friendly name for the rdd, if available. - // This is pretty hacky, is there a better way? - val rddId = k.split("_").last.toInt - val rddName : String = Option(sc.rddNames.get(rddId)).getOrElse(k) - val rddStorageLevel = sc.persistentRdds.get(rddId).getStorageLevel - RDDInfo(rddId, rddName, rddStorageLevel, blockStati.length, tmp._1, tmp._2) - }.toArray - } + } diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index c497f03e0c..97d8c7566d 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -68,6 +68,15 @@ class StorageLevel( override def toString: String = "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) + + def description : String = { + var result = "" + result += (if (useDisk) "Disk " else "") + result += (if (useMemory) "Memory " else "") + result += (if (deserialized) "Deserialized " else "Serialized") + result += "%sx Replicated".format(replication) + result + } } object StorageLevel { diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala new file mode 100644 index 0000000000..ebc7390ee5 --- /dev/null +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -0,0 +1,78 @@ +package spark.storage + +import spark.SparkContext + +private[spark] +case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, + blocks: Map[String, BlockStatus]) { + + def memUsed(blockPrefix: String = "") = { + blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize). + reduceOption(_+_).getOrElse(0l) + } + + def diskUsed(blockPrefix: String = "") = { + blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize). + reduceOption(_+_).getOrElse(0l) + } + + def memRemaining : Long = maxMem - memUsed() + +} + +case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, + numPartitions: Int, memSize: Long, diskSize: Long, locations: Array[BlockManagerId]) + + +/* Helper methods for storage-related objects */ +private[spark] +object StorageUtils { + + /* Given the current storage status of the BlockManager, returns information for each RDD */ + def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus], + sc: SparkContext) : Array[RDDInfo] = { + rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc) + } + + /* Given a list of BlockStatus objets, returns information for each RDD */ + def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus], + sc: SparkContext) : Array[RDDInfo] = { + // Find all RDD Blocks (ignore broadcast variables) + val rddBlocks = infos.filterKeys(_.startsWith("rdd")) + + // Group by rddId, ignore the partition name + val groupedRddBlocks = infos.groupBy { case(k, v) => + k.substring(0,k.lastIndexOf('_')) + }.mapValues(_.values.toArray) + + // For each RDD, generate an RDDInfo object + groupedRddBlocks.map { case(rddKey, rddBlocks) => + + // Add up memory and disk sizes + val memSize = rddBlocks.map(_.memSize).reduce(_ + _) + val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _) + + // Find the id of the RDD, e.g. rdd_1 => 1 + val rddId = rddKey.split("_").last.toInt + // Get the friendly name for the rdd, if available. + val rddName = Option(sc.persistentRdds.get(rddId).name).getOrElse(rddKey) + val rddStorageLevel = sc.persistentRdds.get(rddId).getStorageLevel + + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize, + rddBlocks.map(_.blockManagerId)) + }.toArray + } + + /* Removes all BlockStatus object that are not part of a block prefix */ + def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus], + prefix: String) : Array[StorageStatus] = { + + storageStatusList.map { status => + val newBlocks = status.blocks.filterKeys(_.startsWith(prefix)) + //val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _) + StorageStatus(status.blockManagerId, status.maxMem, newBlocks) + } + + } + +} \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/index.scala.html b/core/src/main/twirl/spark/storage/index.scala.html index fa7dad51ee..2b337f6133 100644 --- a/core/src/main/twirl/spark/storage/index.scala.html +++ b/core/src/main/twirl/spark/storage/index.scala.html @@ -1,4 +1,5 @@ -@(maxMem: Long, remainingMem: Long, diskSpaceUsed: Long, rdds: List[spark.storage.RDDInfo]) +@(maxMem: Long, remainingMem: Long, diskSpaceUsed: Long, rdds: Array[spark.storage.RDDInfo], storageStatusList: Array[spark.storage.StorageStatus]) +@import spark.Utils @spark.common.html.layout(title = "Storage Dashboard") { @@ -7,16 +8,16 @@
  • Memory: - @{spark.Utils.memoryBytesToString(maxMem - remainingMem)} Used - (@{spark.Utils.memoryBytesToString(remainingMem)} Available)
  • -
  • Disk: @{spark.Utils.memoryBytesToString(diskSpaceUsed)} Used
  • + @{Utils.memoryBytesToString(maxMem - remainingMem)} Used + (@{Utils.memoryBytesToString(remainingMem)} Available) +
  • Disk: @{Utils.memoryBytesToString(diskSpaceUsed)} Used

- +

RDD Summary

@@ -25,4 +26,15 @@
+
+ + +
+
+

Worker Summary

+
+ @worker_table(storageStatusList) +
+
+ } \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd.scala.html b/core/src/main/twirl/spark/storage/rdd.scala.html index 075289c826..ac7f8c981f 100644 --- a/core/src/main/twirl/spark/storage/rdd.scala.html +++ b/core/src/main/twirl/spark/storage/rdd.scala.html @@ -1,4 +1,5 @@ -@(rddInfo: spark.storage.RDDInfo, blocks: Map[String, spark.storage.BlockStatus]) +@(rddInfo: spark.storage.RDDInfo, storageStatusList: Array[spark.storage.StorageStatus]) +@import spark.Utils @spark.common.html.layout(title = "RDD Info ") { @@ -8,21 +9,18 @@
  • Storage Level: - @(if (rddInfo.storageLevel.useDisk) "Disk" else "") - @(if (rddInfo.storageLevel.useMemory) "Memory" else "") - @(if (rddInfo.storageLevel.deserialized) "Deserialized" else "") - @(rddInfo.storageLevel.replication)x Replicated + @(rddInfo.storageLevel.description)
  • Partitions: @(rddInfo.numPartitions)
  • Memory Size: - @{spark.Utils.memoryBytesToString(rddInfo.memSize)} + @{Utils.memoryBytesToString(rddInfo.memSize)}
  • Disk Size: - @{spark.Utils.memoryBytesToString(rddInfo.diskSize)} + @{Utils.memoryBytesToString(rddInfo.diskSize)}
@@ -36,6 +34,7 @@

RDD Summary


+ @@ -47,17 +46,14 @@ - @blocks.map { case (k,v) => + @storageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1).map { case (k,v) => - - + + } @@ -67,4 +63,15 @@ +
+ + +
+
+

Worker Summary

+
+ @worker_table(storageStatusList, "rdd_" + rddInfo.id ) +
+
+ } \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd_row.scala.html b/core/src/main/twirl/spark/storage/rdd_row.scala.html deleted file mode 100644 index 3dd9944e3b..0000000000 --- a/core/src/main/twirl/spark/storage/rdd_row.scala.html +++ /dev/null @@ -1,18 +0,0 @@ -@(rdd: spark.storage.RDDInfo) - - - - - - - - \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd_table.scala.html b/core/src/main/twirl/spark/storage/rdd_table.scala.html index 24f55ccefb..af801cf229 100644 --- a/core/src/main/twirl/spark/storage/rdd_table.scala.html +++ b/core/src/main/twirl/spark/storage/rdd_table.scala.html @@ -1,4 +1,5 @@ -@(rdds: List[spark.storage.RDDInfo]) +@(rdds: Array[spark.storage.RDDInfo]) +@import spark.Utils
@k - @(if (v.storageLevel.useDisk) "Disk" else "") - @(if (v.storageLevel.useMemory) "Memory" else "") - @(if (v.storageLevel.deserialized) "Deserialized" else "") - @(v.storageLevel.replication)x Replicated + @(v.storageLevel.description) @{spark.Utils.memoryBytesToString(v.memSize)}@{spark.Utils.memoryBytesToString(v.diskSize)}@{Utils.memoryBytesToString(v.memSize)}@{Utils.memoryBytesToString(v.diskSize)}
- - @rdd.name - - - @(if (rdd.storageLevel.useDisk) "Disk" else "") - @(if (rdd.storageLevel.useMemory) "Memory" else "") - @(if (rdd.storageLevel.deserialized) "Deserialized" else "") - @(rdd.storageLevel.replication)x Replicated - @rdd.numPartitions@{spark.Utils.memoryBytesToString(rdd.memSize)}@{spark.Utils.memoryBytesToString(rdd.diskSize)}
@@ -12,7 +13,18 @@ @for(rdd <- rdds) { - @rdd_row(rdd) + + + + + + + }
+ + @rdd.name + + @(rdd.storageLevel.description) + @rdd.numPartitions@{Utils.memoryBytesToString(rdd.memSize)}@{Utils.memoryBytesToString(rdd.diskSize)}
\ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/worker_table.scala.html b/core/src/main/twirl/spark/storage/worker_table.scala.html new file mode 100644 index 0000000000..d54b8de4cc --- /dev/null +++ b/core/src/main/twirl/spark/storage/worker_table.scala.html @@ -0,0 +1,24 @@ +@(workersStatusList: Array[spark.storage.StorageStatus], prefix: String = "") +@import spark.Utils + + + + + + + + + + + @for(status <- workersStatusList) { + + + + + + } + +
HostMemory UsageDisk Usage
@(status.blockManagerId.ip + ":" + status.blockManagerId.port) + @(Utils.memoryBytesToString(status.memUsed(prefix))) + (@(Utils.memoryBytesToString(status.memRemaining)) Total Available) + @(Utils.memoryBytesToString(status.diskUsed(prefix)))
\ No newline at end of file -- cgit v1.2.3 From 8a25d530edfa3abcdbe2effcd6bfbe484ac40acb Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 13 Nov 2012 02:16:28 -0800 Subject: Optimized checkpoint writing by reusing FileSystem object. Fixed bug in updating of checkpoint data in DStream where the checkpointed RDDs, upon recovery, were not recognized as checkpointed RDDs and therefore deleted from HDFS. Made InputStreamsSuite more robust to timing delays. --- core/src/main/scala/spark/RDD.scala | 6 +- .../main/scala/spark/streaming/Checkpoint.scala | 73 ++++++++++-------- .../src/main/scala/spark/streaming/DStream.scala | 8 +- .../src/main/scala/spark/streaming/Scheduler.scala | 28 +++++-- .../scala/spark/streaming/StreamingContext.scala | 10 +-- streaming/src/test/resources/log4j.properties | 2 +- .../scala/spark/streaming/CheckpointSuite.scala | 25 +++--- .../scala/spark/streaming/InputStreamsSuite.scala | 88 ++++++++++++---------- 8 files changed, 129 insertions(+), 111 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 63048d5df0..6af8c377b5 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -189,11 +189,7 @@ abstract class RDD[T: ClassManifest]( def getCheckpointData(): Any = { synchronized { - if (isCheckpointed) { - checkpointFile - } else { - null - } + checkpointFile } } diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index a70fb8f73a..770f7b0cc0 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -5,7 +5,7 @@ import spark.{Logging, Utils} import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration -import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream} +import java.io._ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) @@ -18,8 +18,6 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val checkpointDir = ssc.checkpointDir val checkpointInterval = ssc.checkpointInterval - validate() - def validate() { assert(master != null, "Checkpoint.master is null") assert(framework != null, "Checkpoint.framework is null") @@ -27,35 +25,50 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) assert(checkpointTime != null, "Checkpoint.checkpointTime is null") logInfo("Checkpoint for time " + checkpointTime + " validated") } +} - def save(path: String) { - val file = new Path(path, "graph") - val conf = new Configuration() - val fs = file.getFileSystem(conf) - logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'") - if (fs.exists(file)) { - val bkFile = new Path(file.getParent, file.getName + ".bk") - FileUtil.copy(fs, file, fs, bkFile, true, true, conf) - logDebug("Moved existing checkpoint file to " + bkFile) +/** + * Convenience class to speed up the writing of graph checkpoint to file + */ +class CheckpointWriter(checkpointDir: String) extends Logging { + val file = new Path(checkpointDir, "graph") + val conf = new Configuration() + var fs = file.getFileSystem(conf) + val maxAttempts = 3 + + def write(checkpoint: Checkpoint) { + // TODO: maybe do this in a different thread from the main stream execution thread + var attempts = 0 + while (attempts < maxAttempts) { + attempts += 1 + try { + logDebug("Saving checkpoint for time " + checkpoint.checkpointTime + " to file '" + file + "'") + if (fs.exists(file)) { + val bkFile = new Path(file.getParent, file.getName + ".bk") + FileUtil.copy(fs, file, fs, bkFile, true, true, conf) + logDebug("Moved existing checkpoint file to " + bkFile) + } + val fos = fs.create(file) + val oos = new ObjectOutputStream(fos) + oos.writeObject(checkpoint) + oos.close() + logInfo("Checkpoint for time " + checkpoint.checkpointTime + " saved to file '" + file + "'") + fos.close() + return + } catch { + case ioe: IOException => + logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe) + } } - val fos = fs.create(file) - val oos = new ObjectOutputStream(fos) - oos.writeObject(this) - oos.close() - fs.close() - logInfo("Checkpoint of streaming context for time " + checkpointTime + " saved successfully to file '" + file + "'") - } - - def toBytes(): Array[Byte] = { - val bytes = Utils.serialize(this) - bytes + logError("Could not write checkpoint for time " + checkpoint.checkpointTime + " to file '" + file + "'") } } -object Checkpoint extends Logging { - def load(path: String): Checkpoint = { +object CheckpointReader extends Logging { + + def read(path: String): Checkpoint = { val fs = new Path(path).getFileSystem(new Configuration()) val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk")) @@ -82,17 +95,11 @@ object Checkpoint extends Logging { logError("Error loading checkpoint from file '" + file + "'", e) } } else { - logWarning("Could not load checkpoint from file '" + file + "' as it does not exist") + logWarning("Could not read checkpoint from file '" + file + "' as it does not exist") } }) - throw new Exception("Could not load checkpoint from path '" + path + "'") - } - - def fromBytes(bytes: Array[Byte]): Checkpoint = { - val cp = Utils.deserialize[Checkpoint](bytes) - cp.validate() - cp + throw new Exception("Could not read checkpoint from path '" + path + "'") } } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index abf132e45e..7e6f73dd7d 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -289,6 +289,7 @@ extends Serializable with Logging { */ protected[streaming] def updateCheckpointData(currentTime: Time) { logInfo("Updating checkpoint data for time " + currentTime) + // Get the checkpointed RDDs from the generated RDDs val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointData() != null) .map(x => (x._1, x._2.getCheckpointData())) @@ -334,8 +335,11 @@ extends Serializable with Logging { logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs") checkpointData.foreach { case(time, data) => { - logInfo("Restoring checkpointed RDD for time " + time + " from file") - generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString))) + logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'") + val rdd = ssc.sc.objectFile[T](data.toString) + // Set the checkpoint file name to identify this RDD as a checkpointed RDD by updateCheckpointData() + rdd.checkpointFile = data.toString + generatedRDDs += ((time, rdd)) } } dependencies.foreach(_.restoreCheckpointData()) diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index de0fb1f3ad..e2dca91179 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -16,8 +16,16 @@ extends Logging { initLogging() val graph = ssc.graph + val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt val jobManager = new JobManager(ssc, concurrentJobs) + + val checkpointWriter = if (ssc.checkpointInterval != null && ssc.checkpointDir != null) { + new CheckpointWriter(ssc.checkpointDir) + } else { + null + } + val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock") val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] val timer = new RecurringTimer(clock, ssc.graph.batchDuration, generateRDDs(_)) @@ -52,19 +60,23 @@ extends Logging { logInfo("Scheduler stopped") } - def generateRDDs(time: Time) { + private def generateRDDs(time: Time) { SparkEnv.set(ssc.env) logInfo("\n-----------------------------------------------------\n") - graph.generateRDDs(time).foreach(submitJob) - logInfo("Generated RDDs for time " + time) + graph.generateRDDs(time).foreach(jobManager.runJob) graph.forgetOldRDDs(time) - if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { - ssc.doCheckpoint(time) - } + doCheckpoint(time) + logInfo("Generated RDDs for time " + time) } - def submitJob(job: Job) { - jobManager.runJob(job) + private def doCheckpoint(time: Time) { + if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { + val startTime = System.currentTimeMillis() + ssc.graph.updateCheckpointData(time) + checkpointWriter.write(new Checkpoint(ssc, time)) + val stopTime = System.currentTimeMillis() + logInfo("Checkpointing the graph took " + (stopTime - startTime) + " ms") + } } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index ab6d6e8dea..ef6a05a392 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -28,7 +28,7 @@ final class StreamingContext ( def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) = this(new SparkContext(master, frameworkName, sparkHome, jars), null) - def this(path: String) = this(null, Checkpoint.load(path)) + def this(path: String) = this(null, CheckpointReader.read(path)) def this(cp_ : Checkpoint) = this(null, cp_) @@ -225,14 +225,6 @@ final class StreamingContext ( case e: Exception => logWarning("Error while stopping", e) } } - - def doCheckpoint(currentTime: Time) { - val startTime = System.currentTimeMillis() - graph.updateCheckpointData(currentTime) - new Checkpoint(this, currentTime).save(checkpointDir) - val stopTime = System.currentTimeMillis() - logInfo("Checkpointing the graph took " + (stopTime - startTime) + " ms") - } } diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 02fe16866e..33774b463d 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -1,5 +1,5 @@ # Set everything to be logged to the console -log4j.rootCategory=WARN, console +log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 0ad57e38b9..b3afedf39f 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -24,7 +24,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { override def framework = "CheckpointSuite" - override def batchDuration = Milliseconds(200) + override def batchDuration = Milliseconds(500) override def checkpointDir = "checkpoint" @@ -34,7 +34,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { test("basic stream+rdd recovery") { - assert(batchDuration === Milliseconds(200), "batchDuration for this test must be 1 second") + assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") assert(checkpointInterval === batchDuration, "checkpointInterval for this test much be same as batchDuration") System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") @@ -134,9 +134,9 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val operation = (st: DStream[String]) => { st.map(x => (x, 1)) .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) - .checkpoint(Seconds(2)) + .checkpoint(batchDuration * 2) } - testCheckpointedOperation(input, operation, output, 3) + testCheckpointedOperation(input, operation, output, 7) } test("updateStateByKey") { @@ -148,14 +148,18 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { } st.map(x => (x, 1)) .updateStateByKey[RichInt](updateFunc) - .checkpoint(Seconds(2)) + .checkpoint(batchDuration * 2) .map(t => (t._1, t._2.self)) } - testCheckpointedOperation(input, operation, output, 3) + testCheckpointedOperation(input, operation, output, 7) } - - + /** + * Tests a streaming operation under checkpointing, by restart the operation + * from checkpoint file and verifying whether the final output is correct. + * The output is assumed to have come from a reliable queue which an replay + * data as required. + */ def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V], @@ -170,8 +174,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val initialNumExpectedOutputs = initialNumBatches val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs - // Do half the computation (half the number of batches), create checkpoint file and quit - + // Do the computation for initial number of batches, create checkpoint file and quit ssc = setupStreams[U, V](input, operation) val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) @@ -193,8 +196,6 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { * Advances the manual clock on the streaming scheduler by given number of batches. * It also wait for the expected amount of time for each batch. */ - - def runStreamsWithRealDelay(ssc: StreamingContext, numBatches: Long) { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.time) diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 0957748603..3e99440226 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -16,24 +16,36 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + val testPort = 9999 + var testServer: TestServer = null + var testDir: File = null + override def checkpointDir = "checkpoint" after { FileUtils.deleteDirectory(new File(checkpointDir)) + if (testServer != null) { + testServer.stop() + testServer = null + } + if (testDir != null && testDir.exists()) { + FileUtils.deleteDirectory(testDir) + testDir = null + } } test("network input stream") { // Start the server - val serverPort = 9999 - val server = new TestServer(9999) - server.start() + testServer = new TestServer(testPort) + testServer.start() // Set up the streaming context and input streams val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - val networkStream = ssc.networkTextStream("localhost", serverPort, StorageLevel.MEMORY_AND_DISK) + val networkStream = ssc.networkTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]] val outputStream = new TestOutputStream(networkStream, outputBuffer) + def output = outputBuffer.flatMap(x => x) ssc.registerOutputStream(outputStream) ssc.start() @@ -41,21 +53,15 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq(1, 2, 3, 4, 5) val expectedOutput = input.map(_.toString) + Thread.sleep(1000) for (i <- 0 until input.size) { - server.send(input(i).toString + "\n") + testServer.send(input(i).toString + "\n") Thread.sleep(500) clock.addToTime(batchDuration.milliseconds) } - val startTime = System.currentTimeMillis() - while (outputBuffer.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - logInfo("output.size = " + outputBuffer.size + ", expectedOutput.size = " + expectedOutput.size) - Thread.sleep(100) - } Thread.sleep(1000) - val timeTaken = System.currentTimeMillis() - startTime - assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") logInfo("Stopping server") - server.stop() + testServer.stop() logInfo("Stopping context") ssc.stop() @@ -69,24 +75,24 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("--------------------------------") - assert(outputBuffer.size === expectedOutput.size) - for (i <- 0 until outputBuffer.size) { - assert(outputBuffer(i).size === 1) - assert(outputBuffer(i).head === expectedOutput(i)) + // Verify whether all the elements received are as expected + // (whether the elements were received one in each interval is not verified) + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i) === expectedOutput(i)) } } test("network input stream with checkpoint") { // Start the server - val serverPort = 9999 - val server = new TestServer(9999) - server.start() + testServer = new TestServer(testPort) + testServer.start() // Set up the streaming context and input streams var ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) ssc.checkpoint(checkpointDir, checkpointInterval) - val networkStream = ssc.networkTextStream("localhost", serverPort, StorageLevel.MEMORY_AND_DISK) + val networkStream = ssc.networkTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK) var outputStream = new TestOutputStream(networkStream, new ArrayBuffer[Seq[String]]) ssc.registerOutputStream(outputStream) ssc.start() @@ -94,7 +100,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Feed data to the server to send to the network receiver var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] for (i <- Seq(1, 2, 3)) { - server.send(i.toString + "\n") + testServer.send(i.toString + "\n") Thread.sleep(100) clock.addToTime(batchDuration.milliseconds) } @@ -109,7 +115,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { ssc.start() clock = ssc.scheduler.clock.asInstanceOf[ManualClock] for (i <- Seq(4, 5, 6)) { - server.send(i.toString + "\n") + testServer.send(i.toString + "\n") Thread.sleep(100) clock.addToTime(batchDuration.milliseconds) } @@ -120,12 +126,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("file input stream") { + // Create a temporary directory - val dir = { + testDir = { var temp = File.createTempFile(".temp.", Random.nextInt().toString) temp.delete() temp.mkdirs() - temp.deleteOnExit() logInfo("Created temp dir " + temp) temp } @@ -133,10 +139,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Set up the streaming context and input streams val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - val filestream = ssc.textFileStream(dir.toString) + val filestream = ssc.textFileStream(testDir.toString) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] def output = outputBuffer.flatMap(x => x) - val outputStream = new TestOutputStream(filestream, outputBuffer) ssc.registerOutputStream(outputStream) ssc.start() @@ -147,16 +152,16 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val expectedOutput = input.map(_.toString) Thread.sleep(1000) for (i <- 0 until input.size) { - FileUtils.writeStringToFile(new File(dir, i.toString), input(i).toString + "\n") - Thread.sleep(100) + FileUtils.writeStringToFile(new File(testDir, i.toString), input(i).toString + "\n") + Thread.sleep(500) clock.addToTime(batchDuration.milliseconds) - Thread.sleep(100) + //Thread.sleep(100) } val startTime = System.currentTimeMillis() - while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - //println("output.size = " + output.size + ", expectedOutput.size = " + expectedOutput.size) + /*while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + logInfo("output.size = " + output.size + ", expectedOutput.size = " + expectedOutput.size) Thread.sleep(100) - } + }*/ Thread.sleep(1000) val timeTaken = System.currentTimeMillis() - startTime assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") @@ -165,14 +170,16 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether data received by Spark Streaming was as expected logInfo("--------------------------------") - logInfo("output.size = " + output.size) + logInfo("output.size = " + outputBuffer.size) logInfo("output") - output.foreach(x => logInfo("[" + x.mkString(",") + "]")) + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("expected output.size = " + expectedOutput.size) logInfo("expected output") expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("--------------------------------") + // Verify whether all the elements received are as expected + // (whether the elements were received one in each interval is not verified) assert(output.size === expectedOutput.size) for (i <- 0 until output.size) { assert(output(i).size === 1) @@ -182,12 +189,11 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { test("file input stream with checkpoint") { // Create a temporary directory - val dir = { + testDir = { var temp = File.createTempFile(".temp.", Random.nextInt().toString) temp.delete() temp.mkdirs() - temp.deleteOnExit() - println("Created temp dir " + temp) + logInfo("Created temp dir " + temp) temp } @@ -195,7 +201,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { var ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) ssc.checkpoint(checkpointDir, checkpointInterval) - val filestream = ssc.textFileStream(dir.toString) + val filestream = ssc.textFileStream(testDir.toString) var outputStream = new TestOutputStream(filestream, new ArrayBuffer[Seq[String]]) ssc.registerOutputStream(outputStream) ssc.start() @@ -204,7 +210,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] Thread.sleep(1000) for (i <- Seq(1, 2, 3)) { - FileUtils.writeStringToFile(new File(dir, i.toString), i.toString + "\n") + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") Thread.sleep(100) clock.addToTime(batchDuration.milliseconds) } @@ -221,7 +227,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { clock = ssc.scheduler.clock.asInstanceOf[ManualClock] Thread.sleep(500) for (i <- Seq(4, 5, 6)) { - FileUtils.writeStringToFile(new File(dir, i.toString), i.toString + "\n") + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") Thread.sleep(100) clock.addToTime(batchDuration.milliseconds) } -- cgit v1.2.3 From 10c1abcb6ac42b248818fa585a9ad49c2fa4851a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 17 Nov 2012 17:27:00 -0800 Subject: Fixed checkpointing bug in CoGroupedRDD. CoGroupSplits kept around the RDD splits of its parent RDDs, thus checkpointing its parents did not release the references to the parent splits. --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 18 ++++++++++---- core/src/test/scala/spark/CheckpointSuite.scala | 28 ++++++++++++++++++++++ .../src/main/scala/spark/streaming/DStream.scala | 4 ++-- .../main/scala/spark/streaming/DStreamGraph.scala | 2 +- 4 files changed, 45 insertions(+), 7 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index a313ebcbe8..94ef1b56e8 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -12,9 +12,20 @@ import spark.RDD import spark.ShuffleDependency import spark.SparkEnv import spark.Split +import java.io.{ObjectOutputStream, IOException} private[spark] sealed trait CoGroupSplitDep extends Serializable -private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep +private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], splitIndex: Int, var split: Split = null) + extends CoGroupSplitDep { + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + rdd.synchronized { + // Update the reference to parent split at the time of task serialization + split = rdd.splits(splitIndex) + oos.defaultWriteObject() + } + } +} private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep private[spark] @@ -55,7 +66,6 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) @transient var splits_ : Array[Split] = { - val firstRdd = rdds.head val array = new Array[Split](part.numPartitions) for (i <- 0 until array.size) { array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) => @@ -63,7 +73,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) case s: ShuffleDependency[_, _] => new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep case _ => - new NarrowCoGroupSplitDep(r, r.splits(i)): CoGroupSplitDep + new NarrowCoGroupSplitDep(r, i): CoGroupSplitDep } }.toList) } @@ -82,7 +92,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any])) } for ((dep, depNum) <- split.deps.zipWithIndex) dep match { - case NarrowCoGroupSplitDep(rdd, itsSplit) => { + case NarrowCoGroupSplitDep(rdd, itsSplitIndex, itsSplit) => { // Read them from the parent for ((k, v) <- rdd.iterator(itsSplit)) { getSeq(k.asInstanceOf[K])(depNum) += v diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 57dc43ddac..8622ce92aa 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -6,6 +6,7 @@ import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD} import spark.SparkContext._ import storage.StorageLevel import java.util.concurrent.Semaphore +import collection.mutable.ArrayBuffer class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { initLogging() @@ -92,6 +93,33 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val rdd2 = sc.makeRDD(5 to 6, 4).map(x => (x % 2, 1)) testCheckpointing(rdd1 => rdd1.map(x => (x % 2, 1)).cogroup(rdd2)) testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2)) + + // Special test to make sure that the CoGroupSplit of CoGroupedRDD do not + // hold on to the splits of its parent RDDs, as the splits of parent RDDs + // may change while checkpointing. Rather the splits of parent RDDs must + // be fetched at the time of serialization to ensure the latest splits to + // be sent along with the task. + + val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _) + + val ones = sc.parallelize(1 to 100, 1).map(x => (x,1)) + val reduced = ones.reduceByKey(_ + _) + val seqOfCogrouped = new ArrayBuffer[RDD[(Int, Int)]]() + seqOfCogrouped += reduced.cogroup(ones).mapValues[Int](add) + for(i <- 1 to 10) { + seqOfCogrouped += seqOfCogrouped.last.cogroup(ones).mapValues(add) + } + val finalCogrouped = seqOfCogrouped.last + val intermediateCogrouped = seqOfCogrouped(5) + + val bytesBeforeCheckpoint = Utils.serialize(finalCogrouped.splits) + intermediateCogrouped.checkpoint() + finalCogrouped.count() + sleep(intermediateCogrouped) + val bytesAfterCheckpoint = Utils.serialize(finalCogrouped.splits) + println("Before = " + bytesBeforeCheckpoint.size + ", after = " + bytesAfterCheckpoint.size) + assert(bytesAfterCheckpoint.size < bytesBeforeCheckpoint.size, + "CoGroupedSplits still holds on to the splits of its parent RDDs") } /** diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 76cdf8c464..13770aa8fd 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -226,11 +226,11 @@ extends Serializable with Logging { case Some(newRDD) => if (storageLevel != StorageLevel.NONE) { newRDD.persist(storageLevel) - logInfo("Persisting RDD for time " + time + " to " + storageLevel + " at time " + time) + logInfo("Persisting RDD " + newRDD.id + " for time " + time + " to " + storageLevel + " at time " + time) } if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { newRDD.checkpoint() - logInfo("Marking RDD " + newRDD + " for time " + time + " for checkpointing at time " + time) + logInfo("Marking RDD " + newRDD.id + " for time " + time + " for checkpointing at time " + time) } generatedRDDs.put(time, newRDD) Some(newRDD) diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index 246522838a..bd8c033eab 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -105,7 +105,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { private[streaming] def validate() { this.synchronized { assert(batchDuration != null, "Batch duration has not been set") - assert(batchDuration > Milliseconds(100), "Batch duration of " + batchDuration + " is very low") + //assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration + " is very low") assert(getOutputStreams().size > 0, "No output streams registered, so nothing to execute") } } -- cgit v1.2.3 From c97ebf64377e853ab7c616a103869a4417f25954 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 19 Nov 2012 23:22:07 +0000 Subject: Fixed bug in the number of splits in RDD after checkpointing. Modified reduceByKeyAndWindow (naive) computation from window+reduceByKey to reduceByKey+window+reduceByKey. --- conf/streaming-env.sh.template | 2 +- core/src/main/scala/spark/RDD.scala | 3 ++- streaming/src/main/scala/spark/streaming/DStream.scala | 3 ++- streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala | 6 +++++- streaming/src/main/scala/spark/streaming/Scheduler.scala | 2 +- streaming/src/main/scala/spark/streaming/WindowedDStream.scala | 3 +++ 6 files changed, 14 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/conf/streaming-env.sh.template b/conf/streaming-env.sh.template index 6b4094c515..1ea9ba5541 100755 --- a/conf/streaming-env.sh.template +++ b/conf/streaming-env.sh.template @@ -11,7 +11,7 @@ SPARK_JAVA_OPTS+=" -XX:+UseConcMarkSweepGC" -# Using of Kryo serialization can improve serialization performance +# Using Kryo serialization can improve serialization performance # and therefore the throughput of the Spark Streaming programs. However, # using Kryo serialization with custom classes may required you to # register the classes with Kryo. Refer to the Spark documentation diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 6af8c377b5..8af6c9bd6a 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -222,12 +222,13 @@ abstract class RDD[T: ClassManifest]( rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString rdd.saveAsObjectFile(checkpointFile) rdd.synchronized { - rdd.checkpointRDD = context.objectFile[T](checkpointFile) + rdd.checkpointRDD = context.objectFile[T](checkpointFile, rdd.splits.size) rdd.checkpointRDDSplits = rdd.checkpointRDD.splits rdd.changeDependencies(rdd.checkpointRDD) rdd.shouldCheckpoint = false rdd.isCheckpointInProgress = false rdd.isCheckpointed = true + println("Done checkpointing RDD " + rdd.id + ", " + rdd) } } } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 13770aa8fd..26d5ce9198 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -321,7 +321,8 @@ extends Serializable with Logging { } } } - logInfo("Updated checkpoint data for time " + currentTime) + logInfo("Updated checkpoint data for time " + currentTime + ", " + checkpointData.size + " checkpoints, " + + "[" + checkpointData.mkString(",") + "]") } /** diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index e09d27d34f..720e63bba0 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -4,6 +4,7 @@ import spark.streaming.StreamingContext._ import spark.{Manifests, RDD, Partitioner, HashPartitioner} import spark.SparkContext._ +import spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer @@ -115,7 +116,10 @@ extends Serializable { slideTime: Time, partitioner: Partitioner ): DStream[(K, V)] = { - self.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), partitioner) + val cleanedReduceFunc = ssc.sc.clean(reduceFunc) + self.reduceByKey(cleanedReduceFunc, partitioner) + .window(windowTime, slideTime) + .reduceByKey(cleanedReduceFunc, partitioner) } // This method is the efficient sliding window reduce operation, diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index e2dca91179..014021be61 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -17,7 +17,7 @@ extends Logging { val graph = ssc.graph - val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt + val concurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt val jobManager = new JobManager(ssc, concurrentJobs) val checkpointWriter = if (ssc.checkpointInterval != null && ssc.checkpointDir != null) { diff --git a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala index ce89a3f99b..e4d2a634f5 100644 --- a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala @@ -2,6 +2,7 @@ package spark.streaming import spark.RDD import spark.rdd.UnionRDD +import spark.storage.StorageLevel class WindowedDStream[T: ClassManifest]( @@ -18,6 +19,8 @@ class WindowedDStream[T: ClassManifest]( throw new Exception("The slide duration of WindowedDStream (" + _slideTime + ") " + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + parent.persist(StorageLevel.MEMORY_ONLY_SER) + def windowTime: Time = _windowTime override def dependencies = List(parent) -- cgit v1.2.3 From b18d70870a33a4783c6b3b787bef9b0eec30bce0 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 27 Nov 2012 15:08:49 -0800 Subject: Modified bunch HashMaps in Spark to use TimeStampedHashMap and made various modules use CleanupTask to periodically clean up metadata. --- core/src/main/scala/spark/CacheTracker.scala | 6 +- core/src/main/scala/spark/MapOutputTracker.scala | 27 ++++--- .../main/scala/spark/scheduler/DAGScheduler.scala | 13 +++- .../scala/spark/scheduler/ShuffleMapTask.scala | 6 +- core/src/main/scala/spark/util/CleanupTask.scala | 31 ++++++++ .../main/scala/spark/util/TimeStampedHashMap.scala | 87 ++++++++++++++++++++++ .../scala/spark/streaming/StreamingContext.scala | 13 +++- 7 files changed, 165 insertions(+), 18 deletions(-) create mode 100644 core/src/main/scala/spark/util/CleanupTask.scala create mode 100644 core/src/main/scala/spark/util/TimeStampedHashMap.scala (limited to 'core') diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index c5db6ce63a..0ee59bee0f 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -14,6 +14,7 @@ import scala.collection.mutable.HashSet import spark.storage.BlockManager import spark.storage.StorageLevel +import util.{CleanupTask, TimeStampedHashMap} private[spark] sealed trait CacheTrackerMessage @@ -30,7 +31,7 @@ private[spark] case object StopCacheTracker extends CacheTrackerMessage private[spark] class CacheTrackerActor extends Actor with Logging { // TODO: Should probably store (String, CacheType) tuples - private val locs = new HashMap[Int, Array[List[String]]] + private val locs = new TimeStampedHashMap[Int, Array[List[String]]] /** * A map from the slave's host name to its cache size. @@ -38,6 +39,8 @@ private[spark] class CacheTrackerActor extends Actor with Logging { private val slaveCapacity = new HashMap[String, Long] private val slaveUsage = new HashMap[String, Long] + private val cleanupTask = new CleanupTask("CacheTracker", locs.cleanup) + private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host) @@ -86,6 +89,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { case StopCacheTracker => logInfo("Stopping CacheTrackerActor") sender ! true + cleanupTask.cancel() context.stop(self) } } diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 45441aa5e5..d0be1bb913 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -17,6 +17,7 @@ import scala.collection.mutable.HashSet import scheduler.MapStatus import spark.storage.BlockManagerId import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import util.{CleanupTask, TimeStampedHashMap} private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String) @@ -43,7 +44,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea val timeout = 10.seconds - var mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]] + var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. @@ -52,7 +53,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea // Cache a serialized version of the output statuses for each shuffle to send them out faster var cacheGeneration = generation - val cachedSerializedStatuses = new HashMap[Int, Array[Byte]] + val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] var trackerActor: ActorRef = if (isMaster) { val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName) @@ -63,6 +64,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea actorSystem.actorFor(url) } + val cleanupTask = new CleanupTask("MapOutputTracker", this.cleanup) + // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. def askTracker(message: Any): Any = { @@ -83,14 +86,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.get(shuffleId) != null) { + if (mapStatuses.get(shuffleId) != None) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - var array = mapStatuses.get(shuffleId) + var array = mapStatuses(shuffleId) array.synchronized { array(mapId) = status } @@ -107,7 +110,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea } def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var array = mapStatuses.get(shuffleId) + var array = mapStatuses(shuffleId) if (array != null) { array.synchronized { if (array(mapId).address == bmAddress) { @@ -125,7 +128,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { - val statuses = mapStatuses.get(shuffleId) + val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") fetching.synchronized { @@ -138,7 +141,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea case e: InterruptedException => } } - return mapStatuses.get(shuffleId).map(status => + return mapStatuses(shuffleId).map(status => (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId)))) } else { fetching += shuffleId @@ -164,9 +167,15 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea } } + def cleanup(cleanupTime: Long) { + mapStatuses.cleanup(cleanupTime) + cachedSerializedStatuses.cleanup(cleanupTime) + } + def stop() { communicate(StopMapOutputTracker) mapStatuses.clear() + cleanupTask.cancel() trackerActor = null } @@ -192,7 +201,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea generationLock.synchronized { if (newGen > generation) { logInfo("Updating generation to " + newGen + " and clearing cache") - mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]] + mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] generation = newGen } } @@ -210,7 +219,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea case Some(bytes) => return bytes case None => - statuses = mapStatuses.get(shuffleId) + statuses = mapStatuses(shuffleId) generationGotten = generation } } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index aaaed59c4a..3af877b817 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -14,6 +14,7 @@ import spark.partial.ApproximateEvaluator import spark.partial.PartialResult import spark.storage.BlockManagerMaster import spark.storage.BlockManagerId +import util.{CleanupTask, TimeStampedHashMap} /** * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for @@ -61,9 +62,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val nextStageId = new AtomicInteger(0) - val idToStage = new HashMap[Int, Stage] + val idToStage = new TimeStampedHashMap[Int, Stage] - val shuffleToMapStage = new HashMap[Int, Stage] + val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] var cacheLocs = new HashMap[Int, Array[List[String]]] @@ -83,6 +84,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] + val cleanupTask = new CleanupTask("DAGScheduler", this.cleanup) + // Start a thread to run the DAGScheduler event loop new Thread("DAGScheduler") { setDaemon(true) @@ -591,8 +594,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with return Nil } + def cleanup(cleanupTime: Long) { + idToStage.cleanup(cleanupTime) + shuffleToMapStage.cleanup(cleanupTime) + } + def stop() { eventQueue.put(StopDAGScheduler) + cleanupTask.cancel() taskSched.stop() } } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 60105c42b6..fbf618c906 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -14,17 +14,19 @@ import com.ning.compress.lzf.LZFOutputStream import spark._ import spark.storage._ +import util.{TimeStampedHashMap, CleanupTask} private[spark] object ShuffleMapTask { // A simple map between the stage id to the serialized byte array of a task. // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. - val serializedInfoCache = new JHashMap[Int, Array[Byte]] + val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] + val cleanupTask = new CleanupTask("ShuffleMapTask", serializedInfoCache.cleanup) def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { - val old = serializedInfoCache.get(stageId) + val old = serializedInfoCache.get(stageId).orNull if (old != null) { return old } else { diff --git a/core/src/main/scala/spark/util/CleanupTask.scala b/core/src/main/scala/spark/util/CleanupTask.scala new file mode 100644 index 0000000000..ccc28803e0 --- /dev/null +++ b/core/src/main/scala/spark/util/CleanupTask.scala @@ -0,0 +1,31 @@ +package spark.util + +import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} +import java.util.{TimerTask, Timer} +import spark.Logging + +class CleanupTask(name: String, cleanupFunc: (Long) => Unit) extends Logging { + val delayMins = System.getProperty("spark.cleanup.delay", "-100").toInt + val periodMins = System.getProperty("spark.cleanup.period", (delayMins / 10).toString).toInt + val timer = new Timer(name + " cleanup timer", true) + val task = new TimerTask { + def run() { + try { + if (delayMins > 0) { + + cleanupFunc(System.currentTimeMillis() - (delayMins * 60 * 1000)) + logInfo("Ran cleanup task for " + name) + } + } catch { + case e: Exception => logError("Error running cleanup task for " + name, e) + } + } + } + if (periodMins > 0) { + timer.schedule(task, periodMins * 60 * 1000, periodMins * 60 * 1000) + } + + def cancel() { + timer.cancel() + } +} diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala new file mode 100644 index 0000000000..7a22b80a20 --- /dev/null +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -0,0 +1,87 @@ +package spark.util + +import scala.collection.JavaConversions._ +import scala.collection.mutable.{HashMap, Map} +import java.util.concurrent.ConcurrentHashMap + +/** + * This is a custom implementation of scala.collection.mutable.Map which stores the insertion + * time stamp along with each key-value pair. Key-value pairs that are older than a particular + * threshold time can them be removed using the cleanup method. This is intended to be a drop-in + * replacement of scala.collection.mutable.HashMap. + */ +class TimeStampedHashMap[A, B] extends Map[A, B]() { + val internalMap = new ConcurrentHashMap[A, (B, Long)]() + + def get(key: A): Option[B] = { + val value = internalMap.get(key) + if (value != null) Some(value._1) else None + } + + def iterator: Iterator[(A, B)] = { + val jIterator = internalMap.entrySet().iterator() + jIterator.map(kv => (kv.getKey, kv.getValue._1)) + } + + override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { + val newMap = new TimeStampedHashMap[A, B1] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.put(kv._1, (kv._2, currentTime)) + newMap + } + + override def - (key: A): Map[A, B] = { + internalMap.remove(key) + this + } + + override def += (kv: (A, B)): this.type = { + internalMap.put(kv._1, (kv._2, currentTime)) + this + } + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def update(key: A, value: B) { + this += ((key, value)) + } + + override def apply(key: A): B = { + val value = internalMap.get(key) + if (value == null) throw new NoSuchElementException() + value._1 + } + + override def filter(p: ((A, B)) => Boolean): Map[A, B] = { + internalMap.map(kv => (kv._1, kv._2._1)).filter(p) + } + + override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() + + override def size(): Int = internalMap.size() + + override def foreach[U](f: ((A, B)) => U): Unit = { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val kv = (entry.getKey, entry.getValue._1) + f(kv) + } + } + + def cleanup(threshTime: Long) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue._2 < threshTime) { + iterator.remove() + } + } + } + + private def currentTime: Long = System.currentTimeMillis() + +} diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 4a41f2f516..58123dc82c 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -43,7 +43,7 @@ class StreamingContext private ( * @param batchDuration The time interval at which streaming data will be divided into batches */ def this(master: String, frameworkName: String, batchDuration: Time) = - this(new SparkContext(master, frameworkName), null, batchDuration) + this(StreamingContext.createNewSparkContext(master, frameworkName), null, batchDuration) /** * Recreates the StreamingContext from a checkpoint file. @@ -214,11 +214,8 @@ class StreamingContext private ( "Checkpoint directory has been set, but the graph checkpointing interval has " + "not been set. Please use StreamingContext.checkpoint() to set the interval." ) - - } - /** * This function starts the execution of the streams. */ @@ -265,6 +262,14 @@ class StreamingContext private ( object StreamingContext { + + def createNewSparkContext(master: String, frameworkName: String): SparkContext = { + if (System.getProperty("spark.cleanup.delay", "-1").toInt < 0) { + System.setProperty("spark.cleanup.delay", "60") + } + new SparkContext(master, frameworkName) + } + implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = { new PairDStreamFunctions[K, V](stream) } -- cgit v1.2.3 From d5e7aad039603a8a02d11f9ebda001422ca4c341 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 28 Nov 2012 08:36:55 +0000 Subject: Bug fixes --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 11 ++++++++++- core/src/main/scala/spark/util/CleanupTask.scala | 17 +++++++++-------- .../main/scala/spark/streaming/StreamingContext.scala | 2 +- 3 files changed, 20 insertions(+), 10 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 3af877b817..affacb43ca 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -78,7 +78,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val running = new HashSet[Stage] // Stages we are running right now val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures - val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage + val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits val activeJobs = new HashSet[ActiveJob] @@ -595,8 +595,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } def cleanup(cleanupTime: Long) { + var sizeBefore = idToStage.size idToStage.cleanup(cleanupTime) + logInfo("idToStage " + sizeBefore + " --> " + idToStage.size) + + sizeBefore = shuffleToMapStage.size shuffleToMapStage.cleanup(cleanupTime) + logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size) + + sizeBefore = pendingTasks.size + pendingTasks.cleanup(cleanupTime) + logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size) } def stop() { diff --git a/core/src/main/scala/spark/util/CleanupTask.scala b/core/src/main/scala/spark/util/CleanupTask.scala index ccc28803e0..a4357c62c6 100644 --- a/core/src/main/scala/spark/util/CleanupTask.scala +++ b/core/src/main/scala/spark/util/CleanupTask.scala @@ -5,24 +5,25 @@ import java.util.{TimerTask, Timer} import spark.Logging class CleanupTask(name: String, cleanupFunc: (Long) => Unit) extends Logging { - val delayMins = System.getProperty("spark.cleanup.delay", "-100").toInt - val periodMins = System.getProperty("spark.cleanup.period", (delayMins / 10).toString).toInt + val delaySeconds = (System.getProperty("spark.cleanup.delay", "-100").toDouble * 60).toInt + val periodSeconds = math.max(10, delaySeconds / 10) val timer = new Timer(name + " cleanup timer", true) val task = new TimerTask { def run() { try { - if (delayMins > 0) { - - cleanupFunc(System.currentTimeMillis() - (delayMins * 60 * 1000)) + if (delaySeconds > 0) { + cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) logInfo("Ran cleanup task for " + name) - } + } } catch { case e: Exception => logError("Error running cleanup task for " + name, e) } } } - if (periodMins > 0) { - timer.schedule(task, periodMins * 60 * 1000, periodMins * 60 * 1000) + if (periodSeconds > 0) { + logInfo("Starting cleanup task for " + name + " with delay of " + delaySeconds + " seconds and " + + "period of " + periodSeconds + " secs") + timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) } def cancel() { diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 58123dc82c..90dd560752 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -264,7 +264,7 @@ class StreamingContext private ( object StreamingContext { def createNewSparkContext(master: String, frameworkName: String): SparkContext = { - if (System.getProperty("spark.cleanup.delay", "-1").toInt < 0) { + if (System.getProperty("spark.cleanup.delay", "-1").toDouble < 0) { System.setProperty("spark.cleanup.delay", "60") } new SparkContext(master, frameworkName) -- cgit v1.2.3 From e463ae492068d2922e1d50c051a87f8010953dff Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 28 Nov 2012 14:05:01 -0800 Subject: Modified StorageLevel and BlockManagerId to cache common objects and use cached object while deserializing. --- .../main/scala/spark/storage/BlockManager.scala | 28 +------------ .../main/scala/spark/storage/BlockManagerId.scala | 48 ++++++++++++++++++++++ .../main/scala/spark/storage/StorageLevel.scala | 28 ++++++++++++- .../scala/spark/storage/BlockManagerSuite.scala | 26 ++++++++++++ 4 files changed, 101 insertions(+), 29 deletions(-) create mode 100644 core/src/main/scala/spark/storage/BlockManagerId.scala (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 70d6d8369d..e4aa9247a3 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -20,33 +20,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import sun.nio.ch.DirectBuffer -private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { - def this() = this(null, 0) // For deserialization only - - def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) - - override def writeExternal(out: ObjectOutput) { - out.writeUTF(ip) - out.writeInt(port) - } - - override def readExternal(in: ObjectInput) { - ip = in.readUTF() - port = in.readInt() - } - - override def toString = "BlockManagerId(" + ip + ", " + port + ")" - - override def hashCode = ip.hashCode * 41 + port - - override def equals(that: Any) = that match { - case id: BlockManagerId => port == id.port && ip == id.ip - case _ => false - } -} - - -private[spark] +private[spark] case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message) diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala new file mode 100644 index 0000000000..4933cc6606 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -0,0 +1,48 @@ +package spark.storage + +import java.io.{IOException, ObjectOutput, ObjectInput, Externalizable} +import java.util.concurrent.ConcurrentHashMap + +private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { + def this() = this(null, 0) // For deserialization only + + def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) + + override def writeExternal(out: ObjectOutput) { + out.writeUTF(ip) + out.writeInt(port) + } + + override def readExternal(in: ObjectInput) { + ip = in.readUTF() + port = in.readInt() + } + + @throws(classOf[IOException]) + private def readResolve(): Object = { + BlockManagerId.getCachedBlockManagerId(this) + } + + + override def toString = "BlockManagerId(" + ip + ", " + port + ")" + + override def hashCode = ip.hashCode * 41 + port + + override def equals(that: Any) = that match { + case id: BlockManagerId => port == id.port && ip == id.ip + case _ => false + } +} + +object BlockManagerId { + val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() + + def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { + if (blockManagerIdCache.containsKey(id)) { + blockManagerIdCache.get(id) + } else { + blockManagerIdCache.put(id, id) + id + } + } +} \ No newline at end of file diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index c497f03e0c..eb88eb2759 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -1,6 +1,9 @@ package spark.storage -import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.io.{IOException, Externalizable, ObjectInput, ObjectOutput} +import collection.mutable +import util.Random +import collection.mutable.ArrayBuffer /** * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, @@ -17,7 +20,8 @@ class StorageLevel( extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. - + assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") + def this(flags: Int, replication: Int) { this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } @@ -27,6 +31,10 @@ class StorageLevel( override def clone(): StorageLevel = new StorageLevel( this.useDisk, this.useMemory, this.deserialized, this.replication) + override def hashCode(): Int = { + toInt * 41 + replication + } + override def equals(other: Any): Boolean = other match { case s: StorageLevel => s.useDisk == useDisk && @@ -66,6 +74,11 @@ class StorageLevel( replication = in.readByte() } + @throws(classOf[IOException]) + private def readResolve(): Object = { + StorageLevel.getCachedStorageLevel(this) + } + override def toString: String = "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) } @@ -82,4 +95,15 @@ object StorageLevel { val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2) val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) + + val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() + + def getCachedStorageLevel(level: StorageLevel): StorageLevel = { + if (storageLevelCache.containsKey(level)) { + storageLevelCache.get(level) + } else { + storageLevelCache.put(level, level) + level + } + } } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 0e78228134..a2d5e39859 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -57,6 +57,32 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } } + test("StorageLevel object caching") { + val level1 = new StorageLevel(false, false, false, 3) + val level2 = new StorageLevel(false, false, false, 3) + val bytes1 = spark.Utils.serialize(level1) + val level1_ = spark.Utils.deserialize[StorageLevel](bytes1) + val bytes2 = spark.Utils.serialize(level2) + val level2_ = spark.Utils.deserialize[StorageLevel](bytes2) + assert(level1_ === level1, "Deserialized level1 not same as original level1") + assert(level2_ === level2, "Deserialized level2 not same as original level1") + assert(level1_ === level2_, "Deserialized level1 not same as deserialized level2") + assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1") + } + + test("BlockManagerId object caching") { + val id1 = new StorageLevel(false, false, false, 3) + val id2 = new StorageLevel(false, false, false, 3) + val bytes1 = spark.Utils.serialize(id1) + val id1_ = spark.Utils.deserialize[StorageLevel](bytes1) + val bytes2 = spark.Utils.serialize(id2) + val id2_ = spark.Utils.deserialize[StorageLevel](bytes2) + assert(id1_ === id1, "Deserialized id1 not same as original id1") + assert(id2_ === id2, "Deserialized id2 not same as original id1") + assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2") + assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized level1") + } + test("master + 1 manager interaction") { store = new BlockManager(master, serializer, 2000) val a1 = new Array[Byte](400) -- cgit v1.2.3 From 9e9e9e1d898387a1996e4c57128bafadb5938a9b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 28 Nov 2012 18:48:14 -0800 Subject: Renamed CleanupTask to MetadataCleaner. --- core/src/main/scala/spark/CacheTracker.scala | 6 ++-- core/src/main/scala/spark/MapOutputTracker.scala | 6 ++-- .../main/scala/spark/scheduler/DAGScheduler.scala | 6 ++-- .../scala/spark/scheduler/ShuffleMapTask.scala | 5 ++-- core/src/main/scala/spark/util/CleanupTask.scala | 32 ---------------------- .../main/scala/spark/util/MetadataCleaner.scala | 32 ++++++++++++++++++++++ 6 files changed, 44 insertions(+), 43 deletions(-) delete mode 100644 core/src/main/scala/spark/util/CleanupTask.scala create mode 100644 core/src/main/scala/spark/util/MetadataCleaner.scala (limited to 'core') diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 0ee59bee0f..9888f061d9 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -14,7 +14,7 @@ import scala.collection.mutable.HashSet import spark.storage.BlockManager import spark.storage.StorageLevel -import util.{CleanupTask, TimeStampedHashMap} +import util.{MetadataCleaner, TimeStampedHashMap} private[spark] sealed trait CacheTrackerMessage @@ -39,7 +39,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { private val slaveCapacity = new HashMap[String, Long] private val slaveUsage = new HashMap[String, Long] - private val cleanupTask = new CleanupTask("CacheTracker", locs.cleanup) + private val metadataCleaner = new MetadataCleaner("CacheTracker", locs.cleanup) private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) @@ -89,7 +89,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { case StopCacheTracker => logInfo("Stopping CacheTrackerActor") sender ! true - cleanupTask.cancel() + metadataCleaner.cancel() context.stop(self) } } diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index d0be1bb913..20ff5431af 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -17,7 +17,7 @@ import scala.collection.mutable.HashSet import scheduler.MapStatus import spark.storage.BlockManagerId import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import util.{CleanupTask, TimeStampedHashMap} +import util.{MetadataCleaner, TimeStampedHashMap} private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String) @@ -64,7 +64,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea actorSystem.actorFor(url) } - val cleanupTask = new CleanupTask("MapOutputTracker", this.cleanup) + val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup) // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. @@ -175,7 +175,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea def stop() { communicate(StopMapOutputTracker) mapStatuses.clear() - cleanupTask.cancel() + metadataCleaner.cancel() trackerActor = null } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index affacb43ca..4b2570fa2b 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -14,7 +14,7 @@ import spark.partial.ApproximateEvaluator import spark.partial.PartialResult import spark.storage.BlockManagerMaster import spark.storage.BlockManagerId -import util.{CleanupTask, TimeStampedHashMap} +import util.{MetadataCleaner, TimeStampedHashMap} /** * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for @@ -84,7 +84,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] - val cleanupTask = new CleanupTask("DAGScheduler", this.cleanup) + val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) // Start a thread to run the DAGScheduler event loop new Thread("DAGScheduler") { @@ -610,7 +610,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def stop() { eventQueue.put(StopDAGScheduler) - cleanupTask.cancel() + metadataCleaner.cancel() taskSched.stop() } } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index fbf618c906..683f5ebec3 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -14,7 +14,7 @@ import com.ning.compress.lzf.LZFOutputStream import spark._ import spark.storage._ -import util.{TimeStampedHashMap, CleanupTask} +import util.{TimeStampedHashMap, MetadataCleaner} private[spark] object ShuffleMapTask { @@ -22,7 +22,8 @@ private[spark] object ShuffleMapTask { // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - val cleanupTask = new CleanupTask("ShuffleMapTask", serializedInfoCache.cleanup) + + val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.cleanup) def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { diff --git a/core/src/main/scala/spark/util/CleanupTask.scala b/core/src/main/scala/spark/util/CleanupTask.scala deleted file mode 100644 index a4357c62c6..0000000000 --- a/core/src/main/scala/spark/util/CleanupTask.scala +++ /dev/null @@ -1,32 +0,0 @@ -package spark.util - -import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} -import java.util.{TimerTask, Timer} -import spark.Logging - -class CleanupTask(name: String, cleanupFunc: (Long) => Unit) extends Logging { - val delaySeconds = (System.getProperty("spark.cleanup.delay", "-100").toDouble * 60).toInt - val periodSeconds = math.max(10, delaySeconds / 10) - val timer = new Timer(name + " cleanup timer", true) - val task = new TimerTask { - def run() { - try { - if (delaySeconds > 0) { - cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) - logInfo("Ran cleanup task for " + name) - } - } catch { - case e: Exception => logError("Error running cleanup task for " + name, e) - } - } - } - if (periodSeconds > 0) { - logInfo("Starting cleanup task for " + name + " with delay of " + delaySeconds + " seconds and " - + "period of " + periodSeconds + " secs") - timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) - } - - def cancel() { - timer.cancel() - } -} diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala new file mode 100644 index 0000000000..71ac39864e --- /dev/null +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -0,0 +1,32 @@ +package spark.util + +import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} +import java.util.{TimerTask, Timer} +import spark.Logging + +class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { + val delaySeconds = (System.getProperty("spark.cleanup.delay", "-100").toDouble * 60).toInt + val periodSeconds = math.max(10, delaySeconds / 10) + val timer = new Timer(name + " cleanup timer", true) + val task = new TimerTask { + def run() { + try { + if (delaySeconds > 0) { + cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) + logInfo("Ran metadata cleaner for " + name) + } + } catch { + case e: Exception => logError("Error running cleanup task for " + name, e) + } + } + } + if (periodSeconds > 0) { + logInfo("Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and " + + "period of " + periodSeconds + " secs") + timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) + } + + def cancel() { + timer.cancel() + } +} -- cgit v1.2.3 From c9789751bfc496d24e8369a0035d57f0ed8dcb58 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 28 Nov 2012 23:18:24 -0800 Subject: Added metadata cleaner to BlockManager to remove old blocks completely. --- .../main/scala/spark/storage/BlockManager.scala | 47 ++++++++++++++++------ .../scala/spark/storage/BlockManagerMaster.scala | 1 + 2 files changed, 36 insertions(+), 12 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index e4aa9247a3..1e36578e1a 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -10,12 +10,12 @@ import java.nio.{MappedByteBuffer, ByteBuffer} import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} -import scala.collection.JavaConversions._ import spark.{CacheTracker, Logging, SizeEstimator, SparkException, Utils} import spark.network._ import spark.serializer.Serializer -import spark.util.ByteBufferInputStream +import spark.util.{MetadataCleaner, TimeStampedHashMap, ByteBufferInputStream} + import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import sun.nio.ch.DirectBuffer @@ -51,7 +51,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - private val blockInfo = new ConcurrentHashMap[String, BlockInfo](1000) + private val blockInfo = new TimeStampedHashMap[String, BlockInfo]() private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) private[storage] val diskStore: BlockStore = @@ -80,6 +80,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val host = System.getProperty("spark.hostname", Utils.localHostName()) + val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks) initialize() /** @@ -102,8 +103,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * Get storage level of local block. If no info exists for the block, then returns null. */ def getLevel(blockId: String): StorageLevel = { - val info = blockInfo.get(blockId) - if (info != null) info.level else null + blockInfo.get(blockId).map(_.level).orNull } /** @@ -113,9 +113,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def reportBlockStatus(blockId: String) { val (curLevel, inMemSize, onDiskSize) = blockInfo.get(blockId) match { - case null => + case None => (StorageLevel.NONE, 0L, 0L) - case info => + case Some(info) => info.synchronized { info.level match { case null => @@ -173,7 +173,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { info.waitForReady() // In case the block is still being put() by another thread @@ -258,7 +258,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { info.waitForReady() // In case the block is still being put() by another thread @@ -517,7 +517,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new IllegalArgumentException("Storage level is null or invalid") } - val oldBlock = blockInfo.get(blockId) + val oldBlock = blockInfo.get(blockId).orNull if (oldBlock != null) { logWarning("Block " + blockId + " already exists on this machine; not re-adding it") oldBlock.waitForReady() @@ -618,7 +618,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new IllegalArgumentException("Storage level is null or invalid") } - if (blockInfo.containsKey(blockId)) { + if (blockInfo.contains(blockId)) { logWarning("Block " + blockId + " already exists on this machine; not re-adding it") return } @@ -740,7 +740,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { logInfo("Dropping block " + blockId + " from memory") - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { val level = info.level @@ -767,6 +767,29 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } + def dropOldBlocks(cleanupTime: Long) { + logInfo("Dropping blocks older than " + cleanupTime) + val iterator = blockInfo.internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) + if (time < cleanupTime) { + info.synchronized { + val level = info.level + if (level.useMemory) { + memoryStore.remove(id) + } + if (level.useDisk) { + diskStore.remove(id) + } + iterator.remove() + logInfo("Dropped block " + id) + } + reportBlockStatus(id) + } + } + } + def shouldCompress(blockId: String): Boolean = { if (blockId.startsWith("shuffle_")) { compressShuffle diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 397395a65b..af15663621 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -341,6 +341,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor throw new Exception("Self index for " + blockManagerId + " not found") } + // Note that this logic will select the same node multiple times if there aren't enough peers var index = selfIndex while (res.size < size) { index += 1 -- cgit v1.2.3 From 6fcd09f499dca66d255aa7196839156433aae442 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 29 Nov 2012 02:06:33 -0800 Subject: Added TimeStampedHashSet and used that to cleanup the list of registered RDD IDs in CacheTracker. --- core/src/main/scala/spark/CacheTracker.scala | 10 +++- .../main/scala/spark/util/TimeStampedHashMap.scala | 14 +++-- .../main/scala/spark/util/TimeStampedHashSet.scala | 66 ++++++++++++++++++++++ 3 files changed, 81 insertions(+), 9 deletions(-) create mode 100644 core/src/main/scala/spark/util/TimeStampedHashSet.scala (limited to 'core') diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 9888f061d9..cb54e12257 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -14,7 +14,7 @@ import scala.collection.mutable.HashSet import spark.storage.BlockManager import spark.storage.StorageLevel -import util.{MetadataCleaner, TimeStampedHashMap} +import util.{TimeStampedHashSet, MetadataCleaner, TimeStampedHashMap} private[spark] sealed trait CacheTrackerMessage @@ -39,7 +39,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { private val slaveCapacity = new HashMap[String, Long] private val slaveUsage = new HashMap[String, Long] - private val metadataCleaner = new MetadataCleaner("CacheTracker", locs.cleanup) + private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.cleanup) private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) @@ -113,11 +113,15 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b actorSystem.actorFor(url) } - val registeredRddIds = new HashSet[Int] + // TODO: Consider removing this HashSet completely as locs CacheTrackerActor already + // keeps track of registered RDDs + val registeredRddIds = new TimeStampedHashSet[Int] // Remembers which splits are currently being loaded (on worker nodes) val loading = new HashSet[String] + val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.cleanup) + // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. def askTracker(message: Any): Any = { diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index 7a22b80a20..9bcc9245c0 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -1,7 +1,7 @@ package spark.util -import scala.collection.JavaConversions._ -import scala.collection.mutable.{HashMap, Map} +import scala.collection.JavaConversions +import scala.collection.mutable.Map import java.util.concurrent.ConcurrentHashMap /** @@ -20,7 +20,7 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() { def iterator: Iterator[(A, B)] = { val jIterator = internalMap.entrySet().iterator() - jIterator.map(kv => (kv.getKey, kv.getValue._1)) + JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1)) } override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { @@ -31,8 +31,10 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() { } override def - (key: A): Map[A, B] = { - internalMap.remove(key) - this + val newMap = new TimeStampedHashMap[A, B] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.remove(key) + newMap } override def += (kv: (A, B)): this.type = { @@ -56,7 +58,7 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() { } override def filter(p: ((A, B)) => Boolean): Map[A, B] = { - internalMap.map(kv => (kv._1, kv._2._1)).filter(p) + JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) } override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() diff --git a/core/src/main/scala/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/spark/util/TimeStampedHashSet.scala new file mode 100644 index 0000000000..539dd75844 --- /dev/null +++ b/core/src/main/scala/spark/util/TimeStampedHashSet.scala @@ -0,0 +1,66 @@ +package spark.util + +import scala.collection.mutable.Set +import scala.collection.JavaConversions +import java.util.concurrent.ConcurrentHashMap + + +class TimeStampedHashSet[A] extends Set[A] { + val internalMap = new ConcurrentHashMap[A, Long]() + + def contains(key: A): Boolean = { + internalMap.contains(key) + } + + def iterator: Iterator[A] = { + val jIterator = internalMap.entrySet().iterator() + JavaConversions.asScalaIterator(jIterator).map(_.getKey) + } + + override def + (elem: A): Set[A] = { + val newSet = new TimeStampedHashSet[A] + newSet ++= this + newSet += elem + newSet + } + + override def - (elem: A): Set[A] = { + val newSet = new TimeStampedHashSet[A] + newSet ++= this + newSet -= elem + newSet + } + + override def += (key: A): this.type = { + internalMap.put(key, currentTime) + this + } + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def empty: Set[A] = new TimeStampedHashSet[A]() + + override def size(): Int = internalMap.size() + + override def foreach[U](f: (A) => U): Unit = { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + f(iterator.next.getKey) + } + } + + def cleanup(threshTime: Long) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue < threshTime) { + iterator.remove() + } + } + } + + private def currentTime: Long = System.currentTimeMillis() +} -- cgit v1.2.3 From 477de94894b7d8eeed281d33c12bcb2269d117c7 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 1 Dec 2012 13:15:06 -0800 Subject: Minor modifications. --- core/src/main/scala/spark/util/MetadataCleaner.scala | 7 ++++++- streaming/src/main/scala/spark/streaming/DStream.scala | 15 ++++++++++++++- .../scala/spark/streaming/ReducedWindowedDStream.scala | 4 ++-- .../src/main/scala/spark/streaming/StreamingContext.scala | 8 ++++++-- 4 files changed, 28 insertions(+), 6 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index 71ac39864e..2541b26255 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -5,7 +5,7 @@ import java.util.{TimerTask, Timer} import spark.Logging class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { - val delaySeconds = (System.getProperty("spark.cleanup.delay", "-100").toDouble * 60).toInt + val delaySeconds = MetadataCleaner.getDelaySeconds val periodSeconds = math.max(10, delaySeconds / 10) val timer = new Timer(name + " cleanup timer", true) val task = new TimerTask { @@ -30,3 +30,8 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging timer.cancel() } } + +object MetadataCleaner { + def getDelaySeconds = (System.getProperty("spark.cleaner.delay", "-100").toDouble * 60).toInt + def setDelaySeconds(delay: Long) { System.setProperty("spark.cleaner.delay", delay.toString) } +} diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 8efda2074d..28a3e2dfc7 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -146,6 +146,8 @@ extends Serializable with Logging { } protected[streaming] def validate() { + assert(rememberDuration != null, "Remember duration is set to null") + assert( !mustCheckpoint || checkpointInterval != null, "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " + @@ -180,13 +182,24 @@ extends Serializable with Logging { checkpointInterval + "). Please set it to higher than " + checkpointInterval + "." ) + val metadataCleanupDelay = System.getProperty("spark.cleanup.delay", "-1").toDouble + assert( + metadataCleanupDelay < 0 || rememberDuration < metadataCleanupDelay * 60 * 1000, + "It seems you are doing some DStream window operation or setting a checkpoint interval " + + "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + + "than " + rememberDuration.milliseconds + " milliseconds. But the Spark's metadata cleanup" + + "delay is set to " + metadataCleanupDelay + " minutes, which is not sufficient. Please set " + + "the Java property 'spark.cleanup.delay' to more than " + + math.ceil(rememberDuration.millis.toDouble / 60000.0).toInt + " minutes." + ) + dependencies.foreach(_.validate()) logInfo("Slide time = " + slideTime) logInfo("Storage level = " + storageLevel) logInfo("Checkpoint interval = " + checkpointInterval) logInfo("Remember duration = " + rememberDuration) - logInfo("Initialized " + this) + logInfo("Initialized and validated " + this) } protected[streaming] def setContext(s: StreamingContext) { diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index bb852cbcca..f63a9e0011 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -118,8 +118,8 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( if (seqOfValues(0).isEmpty) { // If previous window's reduce value does not exist, then at least new values should exist if (newValues.isEmpty) { - val info = "seqOfValues =\n" + seqOfValues.map(x => "[" + x.mkString(",") + "]").mkString("\n") - throw new Exception("Neither previous window has value for key, nor new values found\n" + info) + throw new Exception("Neither previous window has value for key, nor new values found. " + + "Are you sure your key class hashes consistently?") } // Reduce the new values newValues.reduce(reduceF) // return diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 63d8766749..9c19f6588d 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -17,6 +17,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path import java.util.UUID +import spark.util.MetadataCleaner /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -268,8 +269,11 @@ class StreamingContext private ( object StreamingContext { def createNewSparkContext(master: String, frameworkName: String): SparkContext = { - if (System.getProperty("spark.cleanup.delay", "-1").toDouble < 0) { - System.setProperty("spark.cleanup.delay", "60") + + // Set the default cleaner delay to an hour if not already set. + // This should be sufficient for even 1 second interval. + if (MetadataCleaner.getDelaySeconds < 0) { + MetadataCleaner.setDelaySeconds(60) } new SparkContext(master, frameworkName) } -- cgit v1.2.3 From b4dba55f78b0dfda728cf69c9c17e4863010d28d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 2 Dec 2012 02:03:05 +0000 Subject: Made RDD checkpoint not create a new thread. Fixed bug in detecting when spark.cleaner.delay is insufficient. --- core/src/main/scala/spark/RDD.scala | 31 +++++++--------------- .../main/scala/spark/util/TimeStampedHashMap.scala | 3 ++- .../src/main/scala/spark/streaming/DStream.scala | 9 ++++--- 3 files changed, 17 insertions(+), 26 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 8af6c9bd6a..fbfcfbd704 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -211,28 +211,17 @@ abstract class RDD[T: ClassManifest]( if (startCheckpoint) { val rdd = this - val env = SparkEnv.get - - // Spawn a new thread to do the checkpoint as it takes sometime to write the RDD to file - val th = new Thread() { - override def run() { - // Save the RDD to a file, create a new HadoopRDD from it, - // and change the dependencies from the original parents to the new RDD - SparkEnv.set(env) - rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString - rdd.saveAsObjectFile(checkpointFile) - rdd.synchronized { - rdd.checkpointRDD = context.objectFile[T](checkpointFile, rdd.splits.size) - rdd.checkpointRDDSplits = rdd.checkpointRDD.splits - rdd.changeDependencies(rdd.checkpointRDD) - rdd.shouldCheckpoint = false - rdd.isCheckpointInProgress = false - rdd.isCheckpointed = true - println("Done checkpointing RDD " + rdd.id + ", " + rdd) - } - } + rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString + rdd.saveAsObjectFile(checkpointFile) + rdd.synchronized { + rdd.checkpointRDD = context.objectFile[T](checkpointFile, rdd.splits.size) + rdd.checkpointRDDSplits = rdd.checkpointRDD.splits + rdd.changeDependencies(rdd.checkpointRDD) + rdd.shouldCheckpoint = false + rdd.isCheckpointInProgress = false + rdd.isCheckpointed = true + println("Done checkpointing RDD " + rdd.id + ", " + rdd + ", created RDD " + rdd.checkpointRDD.id + ", " + rdd.checkpointRDD) } - th.start() } else { // Recursively call doCheckpoint() to perform checkpointing on parent RDD if they are marked dependencies.foreach(_.rdd.doCheckpoint()) diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index 9bcc9245c0..52f03784db 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -10,7 +10,7 @@ import java.util.concurrent.ConcurrentHashMap * threshold time can them be removed using the cleanup method. This is intended to be a drop-in * replacement of scala.collection.mutable.HashMap. */ -class TimeStampedHashMap[A, B] extends Map[A, B]() { +class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { val internalMap = new ConcurrentHashMap[A, (B, Long)]() def get(key: A): Option[B] = { @@ -79,6 +79,7 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() { while(iterator.hasNext) { val entry = iterator.next() if (entry.getValue._2 < threshTime) { + logDebug("Removing key " + entry.getKey) iterator.remove() } } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 28a3e2dfc7..d2e9de110e 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -182,14 +182,15 @@ extends Serializable with Logging { checkpointInterval + "). Please set it to higher than " + checkpointInterval + "." ) - val metadataCleanupDelay = System.getProperty("spark.cleanup.delay", "-1").toDouble + val metadataCleanerDelay = spark.util.MetadataCleaner.getDelaySeconds + logInfo("metadataCleanupDelay = " + metadataCleanerDelay) assert( - metadataCleanupDelay < 0 || rememberDuration < metadataCleanupDelay * 60 * 1000, + metadataCleanerDelay < 0 || rememberDuration < metadataCleanerDelay * 1000, "It seems you are doing some DStream window operation or setting a checkpoint interval " + "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + "than " + rememberDuration.milliseconds + " milliseconds. But the Spark's metadata cleanup" + - "delay is set to " + metadataCleanupDelay + " minutes, which is not sufficient. Please set " + - "the Java property 'spark.cleanup.delay' to more than " + + "delay is set to " + (metadataCleanerDelay / 60.0) + " minutes, which is not sufficient. Please set " + + "the Java property 'spark.cleaner.delay' to more than " + math.ceil(rememberDuration.millis.toDouble / 60000.0).toInt + " minutes." ) -- cgit v1.2.3 From a69a82be2682148f5d1ebbdede15a47c90eea73d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 3 Dec 2012 22:37:31 -0800 Subject: Added metadata cleaner to HttpBroadcast to clean up old broacast files. --- .../main/scala/spark/broadcast/HttpBroadcast.scala | 24 ++++++++++++++++++++++ 1 file changed, 24 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 7eb4ddb74f..fef264aab1 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -11,6 +11,7 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream import spark._ import spark.storage.StorageLevel +import util.{MetadataCleaner, TimeStampedHashSet} private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { @@ -64,6 +65,10 @@ private object HttpBroadcast extends Logging { private var serverUri: String = null private var server: HttpServer = null + private val files = new TimeStampedHashSet[String] + private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup) + + def initialize(isMaster: Boolean) { synchronized { if (!initialized) { @@ -85,6 +90,7 @@ private object HttpBroadcast extends Logging { server = null } initialized = false + cleaner.cancel() } } @@ -108,6 +114,7 @@ private object HttpBroadcast extends Logging { val serOut = ser.serializeStream(out) serOut.writeObject(value) serOut.close() + files += file.getAbsolutePath } def read[T](id: Long): T = { @@ -123,4 +130,21 @@ private object HttpBroadcast extends Logging { serIn.close() obj } + + def cleanup(cleanupTime: Long) { + val iterator = files.internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val (file, time) = (entry.getKey, entry.getValue) + if (time < cleanupTime) { + try { + iterator.remove() + new File(file.toString).delete() + logInfo("Deleted broadcast file '" + file + "'") + } catch { + case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) + } + } + } + } } -- cgit v1.2.3 From 21a08529768a5073bc5c15b6c2642ceef2acd0d5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 4 Dec 2012 22:10:25 -0800 Subject: Refactored RDD checkpointing to minimize extra fields in RDD class. --- core/src/main/scala/spark/RDD.scala | 149 ++++++++------------- core/src/main/scala/spark/RDDCheckpointData.scala | 68 ++++++++++ core/src/main/scala/spark/rdd/BlockRDD.scala | 9 +- core/src/main/scala/spark/rdd/CartesianRDD.scala | 10 +- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 2 +- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 2 +- core/src/main/scala/spark/rdd/HadoopRDD.scala | 2 - core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 2 - core/src/main/scala/spark/rdd/UnionRDD.scala | 12 +- .../main/scala/spark/scheduler/DAGScheduler.scala | 2 +- core/src/test/scala/spark/CheckpointSuite.scala | 73 +--------- .../src/main/scala/spark/streaming/DStream.scala | 7 +- 12 files changed, 144 insertions(+), 194 deletions(-) create mode 100644 core/src/main/scala/spark/RDDCheckpointData.scala (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index fbfcfbd704..e9bd131e61 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -99,13 +99,7 @@ abstract class RDD[T: ClassManifest]( val partitioner: Option[Partitioner] = None /** Optionally overridden by subclasses to specify placement preferences. */ - def preferredLocations(split: Split): Seq[String] = { - if (isCheckpointed) { - checkpointRDD.preferredLocations(split) - } else { - Nil - } - } + def preferredLocations(split: Split): Seq[String] = Nil /** The [[spark.SparkContext]] that this RDD was created on. */ def context = sc @@ -118,6 +112,8 @@ abstract class RDD[T: ClassManifest]( // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE + protected[spark] val checkpointData = new RDDCheckpointData(this) + /** Returns the first parent RDD */ protected[spark] def firstParent[U: ClassManifest] = { dependencies.head.rdd.asInstanceOf[RDD[U]] @@ -126,17 +122,6 @@ abstract class RDD[T: ClassManifest]( /** Returns the `i` th parent RDD */ protected[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] - // Variables relating to checkpointing - protected val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD - - protected var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing - protected var isCheckpointInProgress = false // set to true when checkpointing is in progress - protected[spark] var isCheckpointed = false // set to true after checkpointing is completed - - protected[spark] var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed - protected var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file - protected var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD - // Methods available on all RDDs: /** @@ -162,83 +147,14 @@ abstract class RDD[T: ClassManifest]( /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - /** - * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` - * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. - * This is used to truncate very long lineages. In the current implementation, Spark will save - * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. - * Hence, it is strongly recommended to use checkpoint() on RDDs when - * (i) Checkpoint() is called before the any job has been executed on this RDD. - * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will - * require recomputation. - */ - protected[spark] def checkpoint() { - synchronized { - if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) { - // do nothing - } else if (isCheckpointable) { - if (sc.checkpointDir == null) { - throw new Exception("Checkpoint directory has not been set in the SparkContext.") - } - shouldCheckpoint = true - } else { - throw new Exception(this + " cannot be checkpointed") - } - } - } - - def getCheckpointData(): Any = { - synchronized { - checkpointFile - } - } - - /** - * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler after a job - * using this RDD has completed (therefore the RDD has been materialized and - * potentially stored in memory). In case this RDD is not marked for checkpointing, - * doCheckpoint() is called recursively on the parent RDDs. - */ - private[spark] def doCheckpoint() { - val startCheckpoint = synchronized { - if (isCheckpointable && shouldCheckpoint && !isCheckpointInProgress) { - isCheckpointInProgress = true - true - } else { - false - } - } - - if (startCheckpoint) { - val rdd = this - rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString - rdd.saveAsObjectFile(checkpointFile) - rdd.synchronized { - rdd.checkpointRDD = context.objectFile[T](checkpointFile, rdd.splits.size) - rdd.checkpointRDDSplits = rdd.checkpointRDD.splits - rdd.changeDependencies(rdd.checkpointRDD) - rdd.shouldCheckpoint = false - rdd.isCheckpointInProgress = false - rdd.isCheckpointed = true - println("Done checkpointing RDD " + rdd.id + ", " + rdd + ", created RDD " + rdd.checkpointRDD.id + ", " + rdd.checkpointRDD) - } + def getPreferredLocations(split: Split) = { + if (isCheckpointed) { + checkpointData.preferredLocations(split) } else { - // Recursively call doCheckpoint() to perform checkpointing on parent RDD if they are marked - dependencies.foreach(_.rdd.doCheckpoint()) + preferredLocations(split) } } - /** - * Changes the dependencies of this RDD from its original parents to the new [[spark.rdd.HadoopRDD]] - * (`newRDD`) created from the checkpoint file. This method must ensure that all references - * to the original parent RDDs must be removed to enable the parent RDDs to be garbage - * collected. Subclasses of RDD may override this method for implementing their own changing - * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. - */ - protected def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD)) - } - /** * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. * This should ''not'' be called by users directly, but is available for implementors of custom @@ -247,7 +163,7 @@ abstract class RDD[T: ClassManifest]( final def iterator(split: Split): Iterator[T] = { if (isCheckpointed) { // ASSUMPTION: Checkpoint Hadoop RDD will have same number of splits as original - checkpointRDD.iterator(checkpointRDDSplits(split.index)) + checkpointData.iterator(split.index) } else if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) } else { @@ -589,6 +505,55 @@ abstract class RDD[T: ClassManifest]( sc.runJob(this, (iter: Iterator[T]) => iter.toArray) } + /** + * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` + * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. + * This is used to truncate very long lineages. In the current implementation, Spark will save + * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. + * Hence, it is strongly recommended to use checkpoint() on RDDs when + * (i) checkpoint() is called before the any job has been executed on this RDD. + * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will + * require recomputation. + */ + def checkpoint() { + checkpointData.markForCheckpoint() + } + + /** + * Return whether this RDD has been checkpointed or not + */ + def isCheckpointed(): Boolean = { + checkpointData.isCheckpointed() + } + + /** + * Gets the name of the file to which this RDD was checkpointed + */ + def getCheckpointFile(): Option[String] = { + checkpointData.getCheckpointFile() + } + + /** + * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler + * after a job using this RDD has completed (therefore the RDD has been materialized and + * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. + */ + protected[spark] def doCheckpoint() { + checkpointData.doCheckpoint() + dependencies.foreach(_.rdd.doCheckpoint()) + } + + /** + * Changes the dependencies of this RDD from its original parents to the new [[spark.rdd.HadoopRDD]] + * (`newRDD`) created from the checkpoint file. This method must ensure that all references + * to the original parent RDDs must be removed to enable the parent RDDs to be garbage + * collected. Subclasses of RDD may override this method for implementing their own changing + * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + */ + protected[spark] def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD)) + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { synchronized { diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala new file mode 100644 index 0000000000..eb4482acee --- /dev/null +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -0,0 +1,68 @@ +package spark + +import org.apache.hadoop.fs.Path + + + +private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) +extends Serializable { + + class CheckpointState extends Serializable { + var state = 0 + + def mark() { if (state == 0) state = 1 } + def start() { assert(state == 1); state = 2 } + def finish() { assert(state == 2); state = 3 } + + def isMarked() = { state == 1 } + def isInProgress = { state == 2 } + def isCheckpointed = { state == 3 } + } + + val cpState = new CheckpointState() + var cpFile: Option[String] = None + var cpRDD: Option[RDD[T]] = None + var cpRDDSplits: Seq[Split] = Nil + + def markForCheckpoint() = { + rdd.synchronized { cpState.mark() } + } + + def isCheckpointed() = { + rdd.synchronized { cpState.isCheckpointed } + } + + def getCheckpointFile() = { + rdd.synchronized { cpFile } + } + + def doCheckpoint() { + rdd.synchronized { + if (cpState.isMarked && !cpState.isInProgress) { + cpState.start() + } else { + return + } + } + + val file = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString + rdd.saveAsObjectFile(file) + val newRDD = rdd.context.objectFile[T](file, rdd.splits.size) + + rdd.synchronized { + rdd.changeDependencies(newRDD) + cpFile = Some(file) + cpRDD = Some(newRDD) + cpRDDSplits = newRDD.splits + cpState.finish() + } + } + + def preferredLocations(split: Split) = { + cpRDD.get.preferredLocations(split) + } + + def iterator(splitIndex: Int): Iterator[T] = { + cpRDD.get.iterator(cpRDDSplits(splitIndex)) + } +} diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index f4c3f99011..590f9eb738 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -41,12 +41,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St } } - override def preferredLocations(split: Split) = { - if (isCheckpointed) { - checkpointRDD.preferredLocations(split) - } else { - locations_(split.asInstanceOf[BlockRDDSplit].blockId) - } - } + override def preferredLocations(split: Split) = + locations_(split.asInstanceOf[BlockRDDSplit].blockId) } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 458ad38d55..9bfc3f8ca3 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -32,12 +32,8 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( override def splits = splits_ override def preferredLocations(split: Split) = { - if (isCheckpointed) { - checkpointRDD.preferredLocations(split) - } else { - val currSplit = split.asInstanceOf[CartesianSplit] - rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) - } + val currSplit = split.asInstanceOf[CartesianSplit] + rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) } override def compute(split: Split) = { @@ -56,7 +52,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( override def dependencies = deps_ - override protected def changeDependencies(newRDD: RDD[_]) { + override def changeDependencies(newRDD: RDD[_]) { deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) splits_ = newRDD.splits rdd1 = null diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 94ef1b56e8..adfecea966 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -112,7 +112,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) map.iterator } - override protected def changeDependencies(newRDD: RDD[_]) { + override def changeDependencies(newRDD: RDD[_]) { deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) splits_ = newRDD.splits rdds = null diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 5b5f72ddeb..90c3b8bfd8 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -48,7 +48,7 @@ class CoalescedRDD[T: ClassManifest]( override def dependencies = deps_ - override protected def changeDependencies(newRDD: RDD[_]) { + override def changeDependencies(newRDD: RDD[_]) { deps_ = List(new OneToOneDependency(newRDD)) splits_ = newRDD.splits prev = null diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index 19ed56d9c0..a12531ea89 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -115,6 +115,4 @@ class HadoopRDD[K, V]( val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } - - override val isCheckpointable = false } diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index 2875abb2db..c12df5839e 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -93,6 +93,4 @@ class NewHadoopRDD[K, V]( val theSplit = split.asInstanceOf[NewHadoopSplit] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") } - - override val isCheckpointable = false } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 643a174160..30eb8483b6 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -49,15 +49,11 @@ class UnionRDD[T: ClassManifest]( override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() - override def preferredLocations(s: Split): Seq[String] = { - if (isCheckpointed) { - checkpointRDD.preferredLocations(s) - } else { - s.asInstanceOf[UnionSplit[T]].preferredLocations() - } - } + override def preferredLocations(s: Split): Seq[String] = + s.asInstanceOf[UnionSplit[T]].preferredLocations() + - override protected def changeDependencies(newRDD: RDD[_]) { + override def changeDependencies(newRDD: RDD[_]) { deps_ = List(new OneToOneDependency(newRDD)) splits_ = newRDD.splits rdds = null diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 4b2570fa2b..33d35b35d1 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -575,7 +575,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with return cached } // If the RDD has some placement preferences (as is the case for input RDDs), get those - val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList + val rddPrefs = rdd.getPreferredLocations(rdd.splits(partition)).toList if (rddPrefs != Nil) { return rddPrefs } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 8622ce92aa..2cafef444c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -41,7 +41,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(parCollection.dependencies === Nil) val result = parCollection.collect() sleep(parCollection) // slightly extra time as loading classes for the first can take some time - assert(sc.objectFile[Int](parCollection.checkpointFile).collect() === result) + assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.collect() === result) } @@ -54,7 +54,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { blockRDD.checkpoint() val result = blockRDD.collect() sleep(blockRDD) - assert(sc.objectFile[String](blockRDD.checkpointFile).collect() === result) + assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.collect() === result) } @@ -122,35 +122,6 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { "CoGroupedSplits still holds on to the splits of its parent RDDs") } - /** - * This test forces two ResultTasks of the same job to be launched before and after - * the checkpointing of job's RDD is completed. - */ - test("Threading - ResultTasks") { - val op1 = (parCollection: RDD[Int]) => { - parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) }) - } - val op2 = (firstRDD: RDD[(Int, Int)]) => { - firstRDD.map(x => { println("2nd map running on " + x); Thread.sleep(500); x }) - } - testThreading(op1, op2) - } - - /** - * This test forces two ShuffleMapTasks of the same job to be launched before and after - * the checkpointing of job's RDD is completed. - */ - test("Threading - ShuffleMapTasks") { - val op1 = (parCollection: RDD[Int]) => { - parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) }) - } - val op2 = (firstRDD: RDD[(Int, Int)]) => { - firstRDD.groupByKey(2).map(x => { println("2nd map running on " + x); Thread.sleep(500); x }) - } - testThreading(op1, op2) - } - - def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { val parCollection = sc.makeRDD(1 to 4, 4) val operatedRDD = op(parCollection) @@ -159,49 +130,11 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val result = operatedRDD.collect() sleep(operatedRDD) //println(parentRDD + ", " + operatedRDD.dependencies.head.rdd ) - assert(sc.objectFile[U](operatedRDD.checkpointFile).collect() === result) + assert(sc.objectFile[U](operatedRDD.getCheckpointFile.get).collect() === result) assert(operatedRDD.dependencies.head.rdd != parentRDD) assert(operatedRDD.collect() === result) } - def testThreading[U: ClassManifest, V: ClassManifest](op1: (RDD[Int]) => RDD[U], op2: (RDD[U]) => RDD[V]) { - - val parCollection = sc.makeRDD(1 to 2, 2) - - // This is the RDD that is to be checkpointed - val firstRDD = op1(parCollection) - val parentRDD = firstRDD.dependencies.head.rdd - firstRDD.checkpoint() - - // This the RDD that uses firstRDD. This is designed to launch a - // ShuffleMapTask that uses firstRDD. - val secondRDD = op2(firstRDD) - - // Starting first job, to initiate the checkpointing - logInfo("\nLaunching 1st job to initiate checkpointing\n") - firstRDD.collect() - - // Checkpointing has started but not completed yet - Thread.sleep(100) - assert(firstRDD.dependencies.head.rdd === parentRDD) - - // Starting second job; first task of this job will be - // launched _before_ firstRDD is marked as checkpointed - // and the second task will be launched _after_ firstRDD - // is marked as checkpointed - logInfo("\nLaunching 2nd job that is designed to launch tasks " + - "before and after checkpointing is complete\n") - val result = secondRDD.collect() - - // Check whether firstRDD has been successfully checkpointed - assert(firstRDD.dependencies.head.rdd != parentRDD) - - logInfo("\nRecomputing 2nd job to verify the results of the previous computation\n") - // Check whether the result in the previous job was correct or not - val correctResult = secondRDD.collect() - assert(result === correctResult) - } - def sleep(rdd: RDD[_]) { val startTime = System.currentTimeMillis() val maxWaitTime = 5000 diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index d2e9de110e..d290c5927e 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -325,8 +325,9 @@ extends Serializable with Logging { logInfo("Updating checkpoint data for time " + currentTime) // Get the checkpointed RDDs from the generated RDDs - val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointData() != null) - .map(x => (x._1, x._2.getCheckpointData())) + + val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointFile.isDefined) + .map(x => (x._1, x._2.getCheckpointFile.get)) // Make a copy of the existing checkpoint data val oldCheckpointData = checkpointData.clone() @@ -373,7 +374,7 @@ extends Serializable with Logging { logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'") val rdd = ssc.sc.objectFile[T](data.toString) // Set the checkpoint file name to identify this RDD as a checkpointed RDD by updateCheckpointData() - rdd.checkpointFile = data.toString + rdd.checkpointData.cpFile = Some(data.toString) generatedRDDs += ((time, rdd)) } } -- cgit v1.2.3 From 1f3a75ae9e518c003d84fa38a54583ecd841ffdc Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 7 Dec 2012 13:45:52 -0800 Subject: Modified checkpoint testsuite to more comprehensively test checkpointing of various RDDs. Fixed checkpoint bug (splits referring to parent RDDs or parent splits) in UnionRDD and CoalescedRDD. Fixed bug in testing ShuffledRDD. Removed unnecessary and useless map-side combining step for narrow dependencies in CoGroupedRDD. Removed unncessary WeakReference stuff from many other RDDs. --- core/src/main/scala/spark/rdd/CartesianRDD.scala | 1 - core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 9 +- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 25 +- core/src/main/scala/spark/rdd/FilteredRDD.scala | 6 +- core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 6 +- core/src/main/scala/spark/rdd/GlommedRDD.scala | 6 +- .../main/scala/spark/rdd/MapPartitionsRDD.scala | 6 +- .../spark/rdd/MapPartitionsWithSplitRDD.scala | 6 +- core/src/main/scala/spark/rdd/MappedRDD.scala | 6 +- core/src/main/scala/spark/rdd/PipedRDD.scala | 10 +- core/src/main/scala/spark/rdd/SampledRDD.scala | 12 +- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 16 +- core/src/main/scala/spark/rdd/UnionRDD.scala | 19 +- core/src/test/scala/spark/CheckpointSuite.scala | 267 +++++++++++++++++---- 14 files changed, 285 insertions(+), 110 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 9bfc3f8ca3..1d753a5168 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,7 +1,6 @@ package spark.rdd import spark._ -import java.lang.ref.WeakReference private[spark] class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable { diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index adfecea966..57d472666b 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -17,6 +17,7 @@ import java.io.{ObjectOutputStream, IOException} private[spark] sealed trait CoGroupSplitDep extends Serializable private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], splitIndex: Int, var split: Split = null) extends CoGroupSplitDep { + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { rdd.synchronized { @@ -50,12 +51,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) var deps_ = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { - val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) - if (mapSideCombinedRDD.partitioner == Some(part)) { - logInfo("Adding one-to-one dependency with " + mapSideCombinedRDD) - deps += new OneToOneDependency(mapSideCombinedRDD) + if (rdd.partitioner == Some(part)) { + logInfo("Adding one-to-one dependency with " + rdd) + deps += new OneToOneDependency(rdd) } else { logInfo("Adding shuffle dependency with " + rdd) + val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) deps += new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part) } } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 90c3b8bfd8..0b4499e2eb 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -1,9 +1,24 @@ package spark.rdd import spark._ -import java.lang.ref.WeakReference +import java.io.{ObjectOutputStream, IOException} -private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split +private[spark] case class CoalescedRDDSplit( + index: Int, + @transient rdd: RDD[_], + parentsIndices: Array[Int] + ) extends Split { + var parents: Seq[Split] = parentsIndices.map(rdd.splits(_)) + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + rdd.synchronized { + // Update the reference to parent split at the time of task serialization + parents = parentsIndices.map(rdd.splits(_)) + oos.defaultWriteObject() + } + } +} /** * Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of @@ -21,12 +36,12 @@ class CoalescedRDD[T: ClassManifest]( @transient var splits_ : Array[Split] = { val prevSplits = prev.splits if (prevSplits.length < maxPartitions) { - prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) } + prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) } } else { (0 until maxPartitions).map { i => val rangeStart = (i * prevSplits.length) / maxPartitions val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions - new CoalescedRDDSplit(i, prevSplits.slice(rangeStart, rangeEnd)) + new CoalescedRDDSplit(i, prev, (rangeStart until rangeEnd).toArray) }.toArray } } @@ -42,7 +57,7 @@ class CoalescedRDD[T: ClassManifest]( var deps_ : List[Dependency[_]] = List( new NarrowDependency(prev) { def getParents(id: Int): Seq[Int] = - splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index) + splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices } ) diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index 1370cf6faf..02f2e7c246 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -1,15 +1,13 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference private[spark] class FilteredRDD[T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], f: T => Boolean) - extends RDD[T](prev.get) { + extends RDD[T](prev) { override def splits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).filter(f) diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index 6b2cc67568..cdc8ecdcfe 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -1,15 +1,13 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference private[spark] class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], f: T => TraversableOnce[U]) - extends RDD[U](prev.get) { + extends RDD[U](prev) { override def splits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).flatMap(f) diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index 0f0b6ab0ff..df6f61c69d 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -1,13 +1,11 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference private[spark] -class GlommedRDD[T: ClassManifest](prev: WeakReference[RDD[T]]) - extends RDD[Array[T]](prev.get) { +class GlommedRDD[T: ClassManifest](prev: RDD[T]) + extends RDD[Array[T]](prev) { override def splits = firstParent[T].splits override def compute(split: Split) = Array(firstParent[T].iterator(split).toArray).iterator } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index b04f56cfcc..23b9fb023b 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -1,16 +1,14 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false) - extends RDD[U](prev.get) { + extends RDD[U](prev) { override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index 7a4b6ffb03..41955c1d7a 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -1,9 +1,7 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference /** * A variant of the MapPartitionsRDD that passes the split index into the @@ -12,9 +10,9 @@ import java.lang.ref.WeakReference */ private[spark] class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], f: (Int, Iterator[T]) => Iterator[U]) - extends RDD[U](prev.get) { + extends RDD[U](prev) { override def splits = firstParent[T].splits override def compute(split: Split) = f(split.index, firstParent[T].iterator(split)) diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 8fa1872e0a..6f8cb21fd3 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -1,15 +1,13 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference private[spark] class MappedRDD[U: ClassManifest, T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], f: T => U) - extends RDD[U](prev.get) { + extends RDD[U](prev) { override def splits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).map(f) diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index d9293a9d1a..d2047375ea 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -8,11 +8,9 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.io.Source -import spark.OneToOneDependency import spark.RDD import spark.SparkEnv import spark.Split -import java.lang.ref.WeakReference /** @@ -20,16 +18,16 @@ import java.lang.ref.WeakReference * (printing them one per line) and returns the output as a collection of strings. */ class PipedRDD[T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], command: Seq[String], envVars: Map[String, String]) - extends RDD[String](prev.get) { + extends RDD[String](prev) { - def this(prev: WeakReference[RDD[T]], command: Seq[String]) = this(prev, command, Map()) + def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map()) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(prev: WeakReference[RDD[T]], command: String) = this(prev, PipedRDD.tokenize(command)) + def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) override def splits = firstParent[T].splits diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index f273f257f8..c622e14a66 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -7,7 +7,6 @@ import cern.jet.random.engine.DRand import spark.RDD import spark.OneToOneDependency import spark.Split -import java.lang.ref.WeakReference private[spark] class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable { @@ -15,14 +14,14 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali } class SampledRDD[T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], withReplacement: Boolean, frac: Double, seed: Int) - extends RDD[T](prev.get) { + extends RDD[T](prev) { @transient - val splits_ = { + var splits_ : Array[Split] = { val rg = new Random(seed) firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } @@ -51,4 +50,9 @@ class SampledRDD[T: ClassManifest]( firstParent[T].iterator(split.prev).filter(x => (rand.nextDouble <= frac)) } } + + override def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 31774585f4..a9dd3f35ed 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -1,11 +1,7 @@ package spark.rdd -import spark.Partitioner -import spark.RDD -import spark.ShuffleDependency -import spark.SparkEnv -import spark.Split -import java.lang.ref.WeakReference +import spark._ +import scala.Some private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { override val index = idx @@ -14,15 +10,15 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { /** * The resulting RDD from a shuffle (e.g. repartitioning of data). - * @param parent the parent RDD. + * @param prev the parent RDD. * @param part the partitioner used to partition the RDD * @tparam K the key class. * @tparam V the value class. */ class ShuffledRDD[K, V]( - @transient prev: WeakReference[RDD[(K, V)]], + prev: RDD[(K, V)], part: Partitioner) - extends RDD[(K, V)](prev.get.context, List(new ShuffleDependency(prev.get, part))) { + extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) { override val partitioner = Some(part) @@ -37,7 +33,7 @@ class ShuffledRDD[K, V]( } override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = Nil + dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) splits_ = newRDD.splits } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 30eb8483b6..a5948dd1f1 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -3,18 +3,28 @@ package spark.rdd import scala.collection.mutable.ArrayBuffer import spark._ -import java.lang.ref.WeakReference +import java.io.{ObjectOutputStream, IOException} private[spark] class UnionSplit[T: ClassManifest]( - idx: Int, + idx: Int, rdd: RDD[T], - split: Split) + splitIndex: Int, + var split: Split = null) extends Split with Serializable { def iterator() = rdd.iterator(split) def preferredLocations() = rdd.preferredLocations(split) override val index: Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + rdd.synchronized { + // Update the reference to parent split at the time of task serialization + split = rdd.splits(splitIndex) + oos.defaultWriteObject() + } + } } class UnionRDD[T: ClassManifest]( @@ -27,7 +37,7 @@ class UnionRDD[T: ClassManifest]( val array = new Array[Split](rdds.map(_.splits.size).sum) var pos = 0 for (rdd <- rdds; split <- rdd.splits) { - array(pos) = new UnionSplit(pos, rdd, split) + array(pos) = new UnionSplit(pos, rdd, split.index) pos += 1 } array @@ -52,7 +62,6 @@ class UnionRDD[T: ClassManifest]( override def preferredLocations(s: Split): Seq[String] = s.asInstanceOf[UnionSplit[T]].preferredLocations() - override def changeDependencies(newRDD: RDD[_]) { deps_ = List(new OneToOneDependency(newRDD)) splits_ = newRDD.splits diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 2cafef444c..51bd59e2b1 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -2,17 +2,16 @@ package spark import org.scalatest.{BeforeAndAfter, FunSuite} import java.io.File -import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD} +import spark.rdd._ import spark.SparkContext._ import storage.StorageLevel -import java.util.concurrent.Semaphore -import collection.mutable.ArrayBuffer class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { initLogging() var sc: SparkContext = _ var checkpointDir: File = _ + val partitioner = new HashPartitioner(2) before { checkpointDir = File.createTempFile("temp", "") @@ -40,7 +39,6 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { parCollection.checkpoint() assert(parCollection.dependencies === Nil) val result = parCollection.collect() - sleep(parCollection) // slightly extra time as loading classes for the first can take some time assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.collect() === result) @@ -53,7 +51,6 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val blockRDD = new BlockRDD[String](sc, Array(blockId)) blockRDD.checkpoint() val result = blockRDD.collect() - sleep(blockRDD) assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.collect() === result) @@ -68,79 +65,247 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { testCheckpointing(_.mapPartitions(_.map(_.toString))) testCheckpointing(r => new MapPartitionsWithSplitRDD(r, (i: Int, iter: Iterator[Int]) => iter.map(_.toString) )) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), 1000) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x), 1000) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) testCheckpointing(_.pipe(Seq("cat"))) } test("ShuffledRDD") { - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _)) + // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD + testCheckpointing(rdd => { + new ShuffledRDD(rdd.map(x => (x % 2, 1)), partitioner) + }) } test("UnionRDD") { - testCheckpointing(_.union(sc.makeRDD(5 to 6, 4))) + def otherRDD = sc.makeRDD(1 to 10, 4) + testCheckpointing(_.union(otherRDD), false, true) + testParentCheckpointing(_.union(otherRDD), false, true) } test("CartesianRDD") { - testCheckpointing(_.cartesian(sc.makeRDD(5 to 6, 4)), 1000) + def otherRDD = sc.makeRDD(1 to 10, 4) + testCheckpointing(_.cartesian(otherRDD)) + testParentCheckpointing(_.cartesian(otherRDD), true, false) } test("CoalescedRDD") { testCheckpointing(new CoalescedRDD(_, 2)) + + // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed + // Current implementation of CoalescedRDDSplit has transient reference to parent RDD, + // so does not serialize the RDD (not need to check its size). + testParentCheckpointing(new CoalescedRDD(_, 2), true, false) + + // Test that the CoalescedRDDSplit updates parent splits (CoalescedRDDSplit.parents) after + // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits. + // Note that this test is very specific to the current implementation of CoalescedRDDSplits + val ones = sc.makeRDD(1 to 100, 10).map(x => x) + ones.checkpoint // checkpoint that MappedRDD + val coalesced = new CoalescedRDD(ones, 2) + val splitBeforeCheckpoint = + serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit]) + coalesced.count() // do the checkpointing + val splitAfterCheckpoint = + serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit]) + assert( + splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head, + "CoalescedRDDSplit.parents not updated after parent RDD checkpointed" + ) } test("CoGroupedRDD") { - val rdd2 = sc.makeRDD(5 to 6, 4).map(x => (x % 2, 1)) - testCheckpointing(rdd1 => rdd1.map(x => (x % 2, 1)).cogroup(rdd2)) - testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2)) + // Test serialized size + // RDD with long lineage of one-to-one dependencies through cogroup transformations + val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD() + testCheckpointing(rdd1 => { + CheckpointSuite.cogroup(longLineageRDD1, rdd1.map(x => (x % 2, 1)), partitioner) + }, false, true) - // Special test to make sure that the CoGroupSplit of CoGroupedRDD do not - // hold on to the splits of its parent RDDs, as the splits of parent RDDs - // may change while checkpointing. Rather the splits of parent RDDs must - // be fetched at the time of serialization to ensure the latest splits to - // be sent along with the task. + val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD() + testParentCheckpointing(rdd1 => { + CheckpointSuite.cogroup(longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner) + }, false, true) + } - val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _) + /** + * Test checkpointing of the final RDD generated by the given operation. By default, + * this method tests whether the size of serialized RDD has reduced after checkpointing or not. + * It can also test whether the size of serialized RDD splits has reduced after checkpointing or + * not, but this is not done by default as usually the splits do not refer to any RDD and + * therefore never store the lineage. + */ + def testCheckpointing[U: ClassManifest]( + op: (RDD[Int]) => RDD[U], + testRDDSize: Boolean = true, + testRDDSplitSize: Boolean = false + ) { + // Generate the final RDD using given RDD operation + val baseRDD = generateLongLineageRDD + val operatedRDD = op(baseRDD) + val parentRDD = operatedRDD.dependencies.headOption.orNull + val rddType = operatedRDD.getClass.getSimpleName - val ones = sc.parallelize(1 to 100, 1).map(x => (x,1)) - val reduced = ones.reduceByKey(_ + _) - val seqOfCogrouped = new ArrayBuffer[RDD[(Int, Int)]]() - seqOfCogrouped += reduced.cogroup(ones).mapValues[Int](add) - for(i <- 1 to 10) { - seqOfCogrouped += seqOfCogrouped.last.cogroup(ones).mapValues(add) - } - val finalCogrouped = seqOfCogrouped.last - val intermediateCogrouped = seqOfCogrouped(5) - - val bytesBeforeCheckpoint = Utils.serialize(finalCogrouped.splits) - intermediateCogrouped.checkpoint() - finalCogrouped.count() - sleep(intermediateCogrouped) - val bytesAfterCheckpoint = Utils.serialize(finalCogrouped.splits) - println("Before = " + bytesBeforeCheckpoint.size + ", after = " + bytesAfterCheckpoint.size) - assert(bytesAfterCheckpoint.size < bytesBeforeCheckpoint.size, - "CoGroupedSplits still holds on to the splits of its parent RDDs") - } - - def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { - val parCollection = sc.makeRDD(1 to 4, 4) - val operatedRDD = op(parCollection) + // Find serialized sizes before and after the checkpoint + val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) operatedRDD.checkpoint() - val parentRDD = operatedRDD.dependencies.head.rdd val result = operatedRDD.collect() - sleep(operatedRDD) - //println(parentRDD + ", " + operatedRDD.dependencies.head.rdd ) + val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + + // Test whether the checkpoint file has been created assert(sc.objectFile[U](operatedRDD.getCheckpointFile.get).collect() === result) + + // Test whether dependencies have been changed from its earlier parent RDD assert(operatedRDD.dependencies.head.rdd != parentRDD) + + // Test whether the splits have been changed to the new Hadoop splits + assert(operatedRDD.splits.toList === operatedRDD.checkpointData.cpRDDSplits.toList) + + // Test whether the data in the checkpointed RDD is same as original assert(operatedRDD.collect() === result) + + // Test whether serialized size of the RDD has reduced. If the RDD + // does not have any dependency to another RDD (e.g., ParallelCollection, + // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing. + if (testRDDSize) { + println("Size of " + rddType + + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after checkpointing " + + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + // Test whether serialized size of the splits has reduced. If the splits + // do not have any non-transient reference to another RDD or another RDD's splits, it + // does not refer to a lineage and therefore may not reduce in size after checkpointing. + // However, if the original splits before checkpointing do refer to a parent RDD, the splits + // must be forgotten after checkpointing (to remove all reference to parent RDDs) and + // replaced with the HadoopSplits of the checkpointed RDD. + if (testRDDSplitSize) { + println("Size of " + rddType + " splits " + + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]") + assert( + splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, + "Size of " + rddType + " splits did not reduce after checkpointing " + + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" + ) + } } - def sleep(rdd: RDD[_]) { - val startTime = System.currentTimeMillis() - val maxWaitTime = 5000 - while(rdd.isCheckpointed == false && System.currentTimeMillis() < startTime + maxWaitTime) { - Thread.sleep(50) + /** + * Test whether checkpointing of the parent of the generated RDD also + * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent + * RDDs splits. So even if the parent RDD is checkpointed and its splits changed, + * this RDD will remember the splits and therefore potentially the whole lineage. + */ + def testParentCheckpointing[U: ClassManifest]( + op: (RDD[Int]) => RDD[U], + testRDDSize: Boolean, + testRDDSplitSize: Boolean + ) { + // Generate the final RDD using given RDD operation + val baseRDD = generateLongLineageRDD + val operatedRDD = op(baseRDD) + val parentRDD = operatedRDD.dependencies.head.rdd + val rddType = operatedRDD.getClass.getSimpleName + val parentRDDType = parentRDD.getClass.getSimpleName + + // Find serialized sizes before and after the checkpoint + val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one + val result = operatedRDD.collect() + val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + + // Test whether the data in the checkpointed RDD is same as original + assert(operatedRDD.collect() === result) + + // Test whether serialized size of the RDD has reduced because of its parent being + // checkpointed. If this RDD or its parent RDD do not have any dependency + // to another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), it may + // not reduce in size after checkpointing. + if (testRDDSize) { + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after parent checkpointing parent " + parentRDDType + + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + // Test whether serialized size of the splits has reduced because of its parent being + // checkpointed. If the splits do not have any non-transient reference to another RDD + // or another RDD's splits, it does not refer to a lineage and therefore may not reduce + // in size after checkpointing. However, if the splits do refer to the *splits* of a parent + // RDD, then these splits must update reference to the parent RDD splits as the parent RDD's + // splits must have changed after checkpointing. + if (testRDDSplitSize) { + assert( + splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, + "Size of " + rddType + " splits did not reduce after checkpointing parent " + parentRDDType + + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" + ) } - assert(rdd.isCheckpointed === true, "Waiting for checkpoint to complete took more than " + maxWaitTime + " ms") + + } + + /** + * Generate an RDD with a long lineage of one-to-one dependencies. + */ + def generateLongLineageRDD(): RDD[Int] = { + var rdd = sc.makeRDD(1 to 100, 4) + for (i <- 1 to 20) { + rdd = rdd.map(x => x) + } + rdd + } + + /** + * Generate an RDD with a long lineage specifically for CoGroupedRDD. + * A CoGroupedRDD can have a long lineage only one of its parents have a long lineage + * and narrow dependency with this RDD. This method generate such an RDD by a sequence + * of cogroups and mapValues which creates a long lineage of narrow dependencies. + */ + def generateLongLineageRDDForCoGroupedRDD() = { + val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _) + + def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) + + var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones) + for(i <- 1 to 10) { + cogrouped = cogrouped.mapValues(add).cogroup(ones) + } + cogrouped.mapValues(add) + } + + /** + * Get serialized sizes of the RDD and its splits + */ + def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { + (Utils.serialize(rdd).size, Utils.serialize(rdd.splits).size) + } + + /** + * Serialize and deserialize an object. This is useful to verify the objects + * contents after deserialization (e.g., the contents of an RDD split after + * it is sent to a slave along with a task) + */ + def serializeDeserialize[T](obj: T): T = { + val bytes = Utils.serialize(obj) + Utils.deserialize[T](bytes) } } + + +object CheckpointSuite { + // This is a custom cogroup function that does not use mapValues like + // the PairRDDFunctions.cogroup() + def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { + println("First = " + first + ", second = " + second) + new CoGroupedRDD[K]( + Seq(first.asInstanceOf[RDD[(_, _)]], second.asInstanceOf[RDD[(_, _)]]), + part + ).asInstanceOf[RDD[(K, Seq[Seq[V]])]] + } + +} \ No newline at end of file -- cgit v1.2.3 From c36ca10241991d46f2f1513b2c0c5e369d8b34f9 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 3 Nov 2012 15:19:41 -0700 Subject: Adding locality aware parallelize --- core/src/main/scala/spark/ParallelCollection.scala | 11 +++++++++-- core/src/main/scala/spark/SparkContext.scala | 10 +++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9725017b61..4bd9e1bd54 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -2,6 +2,7 @@ package spark import scala.collection.immutable.NumericRange import scala.collection.mutable.ArrayBuffer +import scala.collection.Map private[spark] class ParallelCollectionSplit[T: ClassManifest]( val rddId: Long, @@ -24,7 +25,8 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( private[spark] class ParallelCollection[T: ClassManifest]( @transient sc : SparkContext, @transient data: Seq[T], - numSlices: Int) + numSlices: Int, + locationPrefs : Map[Int,Seq[String]]) extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split @@ -40,7 +42,12 @@ private[spark] class ParallelCollection[T: ClassManifest]( override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator - override def preferredLocations(s: Split): Seq[String] = Nil + override def preferredLocations(s: Split): Seq[String] = { + locationPrefs.get(splits_.indexOf(s)) match { + case Some(s) => s + case _ => Nil + } + } } private object ParallelCollection { diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d7b46bee38..7ae1aea993 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -194,7 +194,7 @@ class SparkContext( /** Distribute a local Scala collection to form an RDD. */ def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { - new ParallelCollection[T](this, seq, numSlices) + new ParallelCollection[T](this, seq, numSlices, Map[Int, Seq[String]]()) } /** Distribute a local Scala collection to form an RDD. */ @@ -202,6 +202,14 @@ class SparkContext( parallelize(seq, numSlices) } + /** Distribute a local Scala collection to form an RDD, with one or more + * location preferences for each object. Create a new partition for each + * collection item. */ + def makeLocalityConstrainedRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { + val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap + new ParallelCollection[T](this, seq.map(_._1), seq.size, indexToPrefs) + } + /** * Read a text file from HDFS, a local file system (available on all nodes), or any * Hadoop-supported file system URI, and return it as an RDD of Strings. -- cgit v1.2.3 From 3e796bdd57297134ed40b20d7692cd9c8cd6efba Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 7 Dec 2012 19:16:35 -0800 Subject: Changes in response to TD's review. --- core/src/main/scala/spark/SparkContext.scala | 6 +++--- .../scala/spark/streaming/FlumeInputDStream.scala | 2 +- .../spark/streaming/NetworkInputDStream.scala | 4 ++-- .../spark/streaming/NetworkInputTracker.scala | 10 ++++----- .../scala/spark/streaming/RawInputDStream.scala | 2 +- .../scala/spark/streaming/SocketInputDStream.scala | 2 +- .../spark/streaming/examples/FlumeEventCount.scala | 24 +++++++++++++++++----- 7 files changed, 32 insertions(+), 18 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 7ae1aea993..3ccdbfe10e 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -203,9 +203,9 @@ class SparkContext( } /** Distribute a local Scala collection to form an RDD, with one or more - * location preferences for each object. Create a new partition for each - * collection item. */ - def makeLocalityConstrainedRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { + * location preferences (hostnames of Spark nodes) for each object. + * Create a new partition for each collection item. */ + def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollection[T](this, seq.map(_._1), seq.size, indexToPrefs) } diff --git a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala index 9c403278c3..2959ce4540 100644 --- a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala @@ -126,5 +126,5 @@ class FlumeReceiver( logInfo("Flume receiver stopped") } - override def getLocationConstraint = Some(host) + override def getLocationPreference = Some(host) } \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala index 052fc8bb74..4e4e9fc942 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -62,8 +62,8 @@ abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Seri /** This method will be called to stop receiving data. */ protected def onStop() - /** This method conveys a placement constraint (hostname) for this receiver. */ - def getLocationConstraint() : Option[String] = None + /** This method conveys a placement preference (hostname) for this receiver. */ + def getLocationPreference() : Option[String] = None /** * This method starts the receiver. First is accesses all the lazy members to diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index 56661c2615..b421f795ee 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -99,13 +99,13 @@ class NetworkInputTracker( def startReceivers() { val receivers = networkInputStreams.map(_.createReceiver()) - // We only honor constraints if all receivers have them - val hasLocationConstraints = receivers.map(_.getLocationConstraint().isDefined).reduce(_ && _) + // Right now, we only honor preferences if all receivers have them + val hasLocationPreferences = receivers.map(_.getLocationPreference().isDefined).reduce(_ && _) val tempRDD = - if (hasLocationConstraints) { - val receiversWithConstraints = receivers.map(r => (r, Seq(r.getLocationConstraint().toString))) - ssc.sc.makeLocalityConstrainedRDD[NetworkReceiver[_]](receiversWithConstraints) + if (hasLocationPreferences) { + val receiversWithPreferences = receivers.map(r => (r, Seq(r.getLocationPreference().toString))) + ssc.sc.makeRDD[NetworkReceiver[_]](receiversWithPreferences) } else { ssc.sc.makeRDD(receivers, receivers.size) diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala index fd51ed47a5..6acaa9aab1 100644 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -31,7 +31,7 @@ class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: S var blockPushingThread: Thread = null - override def getLocationConstraint = None + override def getLocationPreference = None def onStart() { // Open a socket to the target address and keep reading from it diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala index ebbb17a39a..a9e37c0ff0 100644 --- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala @@ -34,7 +34,7 @@ class SocketReceiver[T: ClassManifest]( lazy protected val dataHandler = new DataHandler(this, storageLevel) - override def getLocationConstraint = None + override def getLocationPreference = None protected def onStart() { logInfo("Connecting to " + host + ":" + port) diff --git a/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala b/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala index d76c92fdd5..e60ce483a3 100644 --- a/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala +++ b/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala @@ -4,19 +4,33 @@ import spark.util.IntParam import spark.storage.StorageLevel import spark.streaming._ +/** + * Produce a streaming count of events received from Flume. + * + * This should be used in conjunction with an AvroSink in Flume. It will start + * an Avro server on at the request host:port address and listen for requests. + * Your Flume AvroSink should be pointed to this address. + * + * Usage: FlumeEventCount + * + * is a Spark master URL + * is the host the Flume receiver will be started on - a receiver + * creates a server and listens for flume events. + * is the port the Flume receiver will listen on. + */ object FlumeEventCount { def main(args: Array[String]) { - if (args.length != 4) { + if (args.length != 3) { System.err.println( - "Usage: FlumeEventCount ") + "Usage: FlumeEventCount ") System.exit(1) } - val Array(master, host, IntParam(port), IntParam(batchMillis)) = args + val Array(master, host, IntParam(port)) = args + val batchInterval = Milliseconds(2000) // Create the context and set the batch size - val ssc = new StreamingContext(master, "FlumeEventCount", - Milliseconds(batchMillis)) + val ssc = new StreamingContext(master, "FlumeEventCount", batchInterval) // Create a flume stream val stream = ssc.flumeStream(host,port,StorageLevel.MEMORY_ONLY) -- cgit v1.2.3 From e42721601898ff199ca1c6cfeae159ad3ef691e3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 8 Dec 2012 12:46:59 -0800 Subject: Removed unnecessary testcases. --- core/src/test/scala/spark/CheckpointSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 8622ce92aa..41d84cb01c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -121,7 +121,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(bytesAfterCheckpoint.size < bytesBeforeCheckpoint.size, "CoGroupedSplits still holds on to the splits of its parent RDDs") } - + /* /** * This test forces two ResultTasks of the same job to be launched before and after * the checkpointing of job's RDD is completed. @@ -149,7 +149,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { } testThreading(op1, op2) } - + */ def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { val parCollection = sc.makeRDD(1 to 4, 4) @@ -163,7 +163,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(operatedRDD.dependencies.head.rdd != parentRDD) assert(operatedRDD.collect() === result) } - + /* def testThreading[U: ClassManifest, V: ClassManifest](op1: (RDD[Int]) => RDD[U], op2: (RDD[U]) => RDD[V]) { val parCollection = sc.makeRDD(1 to 2, 2) @@ -201,7 +201,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val correctResult = secondRDD.collect() assert(result === correctResult) } - + */ def sleep(rdd: RDD[_]) { val startTime = System.currentTimeMillis() val maxWaitTime = 5000 -- cgit v1.2.3 From 746afc2e6513d5f32f261ec0dbf2823f78a5e960 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 10 Dec 2012 23:36:37 -0800 Subject: Bunch of bug fixes related to checkpointing in RDDs. RDDCheckpointData object is used to lock all serialization and dependency changes for checkpointing. ResultTask converted to Externalizable and serialized RDD is cached like ShuffleMapTask. --- core/src/main/scala/spark/ParallelCollection.scala | 10 +- core/src/main/scala/spark/RDD.scala | 10 +- core/src/main/scala/spark/RDDCheckpointData.scala | 76 +++++++-- core/src/main/scala/spark/SparkContext.scala | 5 +- core/src/main/scala/spark/rdd/BlockRDD.scala | 9 +- core/src/main/scala/spark/rdd/CartesianRDD.scala | 21 ++- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 18 +- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 8 +- core/src/main/scala/spark/rdd/HadoopRDD.scala | 4 + core/src/main/scala/spark/rdd/UnionRDD.scala | 15 +- .../main/scala/spark/scheduler/ResultTask.scala | 95 ++++++++++- .../scala/spark/scheduler/ShuffleMapTask.scala | 21 ++- core/src/test/scala/spark/CheckpointSuite.scala | 187 ++++++++++++++++++--- 13 files changed, 389 insertions(+), 90 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9725017b61..9d12af6912 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -28,10 +28,11 @@ private[spark] class ParallelCollection[T: ClassManifest]( extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split - // instead. UPDATE: With the new changes to enable checkpointing, this an be done. + // instead. + // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal. @transient - val splits_ = { + var splits_ : Array[Split] = { val slices = ParallelCollection.slice(data, numSlices).toArray slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray } @@ -41,6 +42,11 @@ private[spark] class ParallelCollection[T: ClassManifest]( override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator override def preferredLocations(s: Split): Seq[String] = Nil + + override def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + } } private object ParallelCollection { diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e9bd131e61..efa03d5185 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -163,7 +163,7 @@ abstract class RDD[T: ClassManifest]( final def iterator(split: Split): Iterator[T] = { if (isCheckpointed) { // ASSUMPTION: Checkpoint Hadoop RDD will have same number of splits as original - checkpointData.iterator(split.index) + checkpointData.iterator(split) } else if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) } else { @@ -556,16 +556,12 @@ abstract class RDD[T: ClassManifest]( @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { - synchronized { - oos.defaultWriteObject() - } + oos.defaultWriteObject() } @throws(classOf[IOException]) private def readObject(ois: ObjectInputStream) { - synchronized { - ois.defaultReadObject() - } + ois.defaultReadObject() } } diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index eb4482acee..ff2ed4cdfc 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -1,12 +1,20 @@ package spark import org.apache.hadoop.fs.Path +import rdd.CoalescedRDD +import scheduler.{ResultTask, ShuffleMapTask} - +/** + * This class contains all the information of the regarding RDD checkpointing. + */ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) -extends Serializable { +extends Logging with Serializable { + /** + * This class manages the state transition of an RDD through checkpointing + * [ Not checkpointed --> marked for checkpointing --> checkpointing in progress --> checkpointed ] + */ class CheckpointState extends Serializable { var state = 0 @@ -20,24 +28,30 @@ extends Serializable { } val cpState = new CheckpointState() - var cpFile: Option[String] = None - var cpRDD: Option[RDD[T]] = None - var cpRDDSplits: Seq[Split] = Nil + @transient var cpFile: Option[String] = None + @transient var cpRDD: Option[RDD[T]] = None + @transient var cpRDDSplits: Seq[Split] = Nil + // Mark the RDD for checkpointing def markForCheckpoint() = { - rdd.synchronized { cpState.mark() } + RDDCheckpointData.synchronized { cpState.mark() } } + // Is the RDD already checkpointed def isCheckpointed() = { - rdd.synchronized { cpState.isCheckpointed } + RDDCheckpointData.synchronized { cpState.isCheckpointed } } + // Get the file to which this RDD was checkpointed to as a Option def getCheckpointFile() = { - rdd.synchronized { cpFile } + RDDCheckpointData.synchronized { cpFile } } + // Do the checkpointing of the RDD. Called after the first job using that RDD is over. def doCheckpoint() { - rdd.synchronized { + // If it is marked for checkpointing AND checkpointing is not already in progress, + // then set it to be in progress, else return + RDDCheckpointData.synchronized { if (cpState.isMarked && !cpState.isInProgress) { cpState.start() } else { @@ -45,24 +59,56 @@ extends Serializable { } } + // Save to file, and reload it as an RDD val file = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString rdd.saveAsObjectFile(file) - val newRDD = rdd.context.objectFile[T](file, rdd.splits.size) - rdd.synchronized { - rdd.changeDependencies(newRDD) + val newRDD = { + val hadoopRDD = rdd.context.objectFile[T](file, rdd.splits.size) + + val oldSplits = rdd.splits.size + val newSplits = hadoopRDD.splits.size + + logDebug("RDD splits = " + oldSplits + " --> " + newSplits) + if (newSplits < oldSplits) { + throw new Exception("# splits after checkpointing is less than before " + + "[" + oldSplits + " --> " + newSplits) + } else if (newSplits > oldSplits) { + new CoalescedRDD(hadoopRDD, rdd.splits.size) + } else { + hadoopRDD + } + } + logDebug("New RDD has " + newRDD.splits.size + " splits") + + // Change the dependencies and splits of the RDD + RDDCheckpointData.synchronized { cpFile = Some(file) cpRDD = Some(newRDD) cpRDDSplits = newRDD.splits + rdd.changeDependencies(newRDD) cpState.finish() + RDDCheckpointData.checkpointCompleted() + logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) } } + // Get preferred location of a split after checkpointing def preferredLocations(split: Split) = { - cpRDD.get.preferredLocations(split) + RDDCheckpointData.synchronized { + cpRDD.get.preferredLocations(split) + } } - def iterator(splitIndex: Int): Iterator[T] = { - cpRDD.get.iterator(cpRDDSplits(splitIndex)) + // Get iterator. This is called at the worker nodes. + def iterator(split: Split): Iterator[T] = { + rdd.firstParent[T].iterator(split) + } +} + +private[spark] object RDDCheckpointData { + def checkpointCompleted() { + ShuffleMapTask.clearCache() + ResultTask.clearCache() } } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d7b46bee38..654b1c2eb7 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -40,9 +40,7 @@ import spark.partial.PartialResult import spark.rdd.HadoopRDD import spark.rdd.NewHadoopRDD import spark.rdd.UnionRDD -import spark.scheduler.ShuffleMapTask -import spark.scheduler.DAGScheduler -import spark.scheduler.TaskScheduler +import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} @@ -486,6 +484,7 @@ class SparkContext( clearJars() SparkEnv.set(null) ShuffleMapTask.clearCache() + ResultTask.clearCache() logInfo("Successfully stopped SparkContext") } diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index 590f9eb738..0c8cdd10dd 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -2,7 +2,7 @@ package spark.rdd import scala.collection.mutable.HashMap -import spark.Dependency +import spark.OneToOneDependency import spark.RDD import spark.SparkContext import spark.SparkEnv @@ -17,7 +17,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St extends RDD[T](sc, Nil) { @transient - val splits_ = (0 until blockIds.size).map(i => { + var splits_ : Array[Split] = (0 until blockIds.size).map(i => { new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] }).toArray @@ -43,5 +43,10 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St override def preferredLocations(split: Split) = locations_(split.asInstanceOf[BlockRDDSplit].blockId) + + override def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + } } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 1d753a5168..9975e79b08 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,10 +1,27 @@ package spark.rdd import spark._ +import java.io.{ObjectOutputStream, IOException} private[spark] -class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable { +class CartesianSplit( + idx: Int, + @transient rdd1: RDD[_], + @transient rdd2: RDD[_], + s1Index: Int, + s2Index: Int + ) extends Split { + var s1 = rdd1.splits(s1Index) + var s2 = rdd2.splits(s2Index) override val index: Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + s1 = rdd1.splits(s1Index) + s2 = rdd2.splits(s2Index) + oos.defaultWriteObject() + } } private[spark] @@ -23,7 +40,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) for (s1 <- rdd1.splits; s2 <- rdd2.splits) { val idx = s1.index * numSplitsInRdd2 + s2.index - array(idx) = new CartesianSplit(idx, s1, s2) + array(idx) = new CartesianSplit(idx, rdd1, rdd2, s1.index, s2.index) } array } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 57d472666b..e4e70b13ba 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -20,11 +20,9 @@ private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], splitIndex: Int, va @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { - rdd.synchronized { - // Update the reference to parent split at the time of task serialization - split = rdd.splits(splitIndex) - oos.defaultWriteObject() - } + // Update the reference to parent split at the time of task serialization + split = rdd.splits(splitIndex) + oos.defaultWriteObject() } } private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep @@ -42,7 +40,8 @@ private[spark] class CoGroupAggregator { (b1, b2) => b1 ++ b2 }) with Serializable -class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) +class +CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging { val aggr = new CoGroupAggregator @@ -63,7 +62,9 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) deps.toList } - override def dependencies = deps_ + // Pre-checkpoint dependencies deps_ should be transient (deps_) + // but post-checkpoint dependencies must not be transient (dependencies_) + override def dependencies = if (isCheckpointed) dependencies_ else deps_ @transient var splits_ : Array[Split] = { @@ -114,7 +115,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) } override def changeDependencies(newRDD: RDD[_]) { - deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + deps_ = null + dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) splits_ = newRDD.splits rdds = null } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 0b4499e2eb..088958942e 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -12,11 +12,9 @@ private[spark] case class CoalescedRDDSplit( @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { - rdd.synchronized { - // Update the reference to parent split at the time of task serialization - parents = parentsIndices.map(rdd.splits(_)) - oos.defaultWriteObject() - } + // Update the reference to parent split at the time of task serialization + parents = parentsIndices.map(rdd.splits(_)) + oos.defaultWriteObject() } } diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index a12531ea89..af54f23ebc 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -115,4 +115,8 @@ class HadoopRDD[K, V]( val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } + + override def checkpoint() { + // Do nothing. Hadoop RDD cannot be checkpointed. + } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index a5948dd1f1..808729f18d 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -19,11 +19,9 @@ private[spark] class UnionSplit[T: ClassManifest]( @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { - rdd.synchronized { - // Update the reference to parent split at the time of task serialization - split = rdd.splits(splitIndex) - oos.defaultWriteObject() - } + // Update the reference to parent split at the time of task serialization + split = rdd.splits(splitIndex) + oos.defaultWriteObject() } } @@ -55,7 +53,9 @@ class UnionRDD[T: ClassManifest]( deps.toList } - override def dependencies = deps_ + // Pre-checkpoint dependencies deps_ should be transient (deps_) + // but post-checkpoint dependencies must not be transient (dependencies_) + override def dependencies = if (isCheckpointed) dependencies_ else deps_ override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() @@ -63,7 +63,8 @@ class UnionRDD[T: ClassManifest]( s.asInstanceOf[UnionSplit[T]].preferredLocations() override def changeDependencies(newRDD: RDD[_]) { - deps_ = List(new OneToOneDependency(newRDD)) + deps_ = null + dependencies_ = List(new OneToOneDependency(newRDD)) splits_ = newRDD.splits rdds = null } diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index 2ebd4075a2..bcb9e4956b 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -1,17 +1,73 @@ package spark.scheduler import spark._ +import java.io._ +import util.{MetadataCleaner, TimeStampedHashMap} +import java.util.zip.{GZIPInputStream, GZIPOutputStream} + +private[spark] object ResultTask { + + // A simple map between the stage id to the serialized byte array of a task. + // Served as a cache for task serialization because serialization can be + // expensive on the master node if it needs to launch thousands of tasks. + val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] + + val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.cleanup) + + def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { + synchronized { + val old = serializedInfoCache.get(stageId).orNull + if (old != null) { + return old + } else { + val out = new ByteArrayOutputStream + val ser = SparkEnv.get.closureSerializer.newInstance + val objOut = ser.serializeStream(new GZIPOutputStream(out)) + objOut.writeObject(rdd) + objOut.writeObject(func) + objOut.close() + val bytes = out.toByteArray + serializedInfoCache.put(stageId, bytes) + return bytes + } + } + } + + def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = { + synchronized { + val loader = Thread.currentThread.getContextClassLoader + val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) + val ser = SparkEnv.get.closureSerializer.newInstance + val objIn = ser.deserializeStream(in) + val rdd = objIn.readObject().asInstanceOf[RDD[_]] + val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] + return (rdd, func) + } + } + + def clearCache() { + synchronized { + serializedInfoCache.clear() + } + } +} + private[spark] class ResultTask[T, U]( stageId: Int, - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - val partition: Int, + var rdd: RDD[T], + var func: (TaskContext, Iterator[T]) => U, + var partition: Int, @transient locs: Seq[String], val outputId: Int) - extends Task[U](stageId) { - - val split = rdd.splits(partition) + extends Task[U](stageId) with Externalizable { + + def this() = this(0, null, null, 0, null, 0) + var split = if (rdd == null) { + null + } else { + rdd.splits(partition) + } override def run(attemptId: Long): U = { val context = new TaskContext(stageId, partition, attemptId) @@ -21,4 +77,31 @@ private[spark] class ResultTask[T, U]( override def preferredLocations: Seq[String] = locs override def toString = "ResultTask(" + stageId + ", " + partition + ")" + + override def writeExternal(out: ObjectOutput) { + RDDCheckpointData.synchronized { + split = rdd.splits(partition) + out.writeInt(stageId) + val bytes = ResultTask.serializeInfo( + stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) + out.writeInt(bytes.length) + out.write(bytes) + out.writeInt(partition) + out.writeInt(outputId) + out.writeObject(split) + } + } + + override def readExternal(in: ObjectInput) { + val stageId = in.readInt() + val numBytes = in.readInt() + val bytes = new Array[Byte](numBytes) + in.readFully(bytes) + val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) + rdd = rdd_.asInstanceOf[RDD[T]] + func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] + partition = in.readInt() + val outputId = in.readInt() + split = in.readObject().asInstanceOf[Split] + } } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 683f5ebec3..5d28c40778 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -84,19 +84,22 @@ private[spark] class ShuffleMapTask( def this() = this(0, null, null, 0, null) var split = if (rdd == null) { - null - } else { + null + } else { rdd.splits(partition) } override def writeExternal(out: ObjectOutput) { - out.writeInt(stageId) - val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partition) - out.writeLong(generation) - out.writeObject(split) + RDDCheckpointData.synchronized { + split = rdd.splits(partition) + out.writeInt(stageId) + val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) + out.writeInt(bytes.length) + out.write(bytes) + out.writeInt(partition) + out.writeLong(generation) + out.writeObject(split) + } } override def readExternal(in: ObjectInput) { diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 51bd59e2b1..7b323e089c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -34,13 +34,30 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { } } + test("RDDs with one-to-one dependencies") { + testCheckpointing(_.map(x => x.toString)) + testCheckpointing(_.flatMap(x => 1 to x)) + testCheckpointing(_.filter(_ % 2 == 0)) + testCheckpointing(_.sample(false, 0.5, 0)) + testCheckpointing(_.glom()) + testCheckpointing(_.mapPartitions(_.map(_.toString))) + testCheckpointing(r => new MapPartitionsWithSplitRDD(r, + (i: Int, iter: Iterator[Int]) => iter.map(_.toString) )) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) + testCheckpointing(_.pipe(Seq("cat"))) + } + test("ParallelCollection") { - val parCollection = sc.makeRDD(1 to 4) + val parCollection = sc.makeRDD(1 to 4, 2) + val numSplits = parCollection.splits.size parCollection.checkpoint() assert(parCollection.dependencies === Nil) val result = parCollection.collect() assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) + assert(parCollection.splits.length === numSplits) + assert(parCollection.splits.toList === parCollection.checkpointData.cpRDDSplits.toList) assert(parCollection.collect() === result) } @@ -49,44 +66,58 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val blockManager = SparkEnv.get.blockManager blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) val blockRDD = new BlockRDD[String](sc, Array(blockId)) + val numSplits = blockRDD.splits.size blockRDD.checkpoint() val result = blockRDD.collect() assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) + assert(blockRDD.splits.length === numSplits) + assert(blockRDD.splits.toList === blockRDD.checkpointData.cpRDDSplits.toList) assert(blockRDD.collect() === result) } - test("RDDs with one-to-one dependencies") { - testCheckpointing(_.map(x => x.toString)) - testCheckpointing(_.flatMap(x => 1 to x)) - testCheckpointing(_.filter(_ % 2 == 0)) - testCheckpointing(_.sample(false, 0.5, 0)) - testCheckpointing(_.glom()) - testCheckpointing(_.mapPartitions(_.map(_.toString))) - testCheckpointing(r => new MapPartitionsWithSplitRDD(r, - (i: Int, iter: Iterator[Int]) => iter.map(_.toString) )) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) - testCheckpointing(_.pipe(Seq("cat"))) - } - test("ShuffledRDD") { - // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD testCheckpointing(rdd => { + // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD new ShuffledRDD(rdd.map(x => (x % 2, 1)), partitioner) }) } test("UnionRDD") { def otherRDD = sc.makeRDD(1 to 10, 4) + + // Test whether the size of UnionRDDSplits reduce in size after parent RDD is checkpointed. + // Current implementation of UnionRDD has transient reference to parent RDDs, + // so only the splits will reduce in serialized size, not the RDD. testCheckpointing(_.union(otherRDD), false, true) testParentCheckpointing(_.union(otherRDD), false, true) } test("CartesianRDD") { - def otherRDD = sc.makeRDD(1 to 10, 4) - testCheckpointing(_.cartesian(otherRDD)) - testParentCheckpointing(_.cartesian(otherRDD), true, false) + def otherRDD = sc.makeRDD(1 to 10, 1) + testCheckpointing(new CartesianRDD(sc, _, otherRDD)) + + // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed + // Current implementation of CoalescedRDDSplit has transient reference to parent RDD, + // so only the RDD will reduce in serialized size, not the splits. + testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false) + + // Test that the CartesianRDD updates parent splits (CartesianRDD.s1/s2) after + // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits. + // Note that this test is very specific to the current implementation of CartesianRDD. + val ones = sc.makeRDD(1 to 100, 10).map(x => x) + ones.checkpoint // checkpoint that MappedRDD + val cartesian = new CartesianRDD(sc, ones, ones) + val splitBeforeCheckpoint = + serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit]) + cartesian.count() // do the checkpointing + val splitAfterCheckpoint = + serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit]) + assert( + (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) && + (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2), + "CartesianRDD.parents not updated after parent RDD checkpointed" + ) } test("CoalescedRDD") { @@ -94,7 +125,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed // Current implementation of CoalescedRDDSplit has transient reference to parent RDD, - // so does not serialize the RDD (not need to check its size). + // so only the RDD will reduce in serialized size, not the splits. testParentCheckpointing(new CoalescedRDD(_, 2), true, false) // Test that the CoalescedRDDSplit updates parent splits (CoalescedRDDSplit.parents) after @@ -145,13 +176,14 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val operatedRDD = op(baseRDD) val parentRDD = operatedRDD.dependencies.headOption.orNull val rddType = operatedRDD.getClass.getSimpleName + val numSplits = operatedRDD.splits.length // Find serialized sizes before and after the checkpoint val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) operatedRDD.checkpoint() val result = operatedRDD.collect() val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - + // Test whether the checkpoint file has been created assert(sc.objectFile[U](operatedRDD.getCheckpointFile.get).collect() === result) @@ -160,6 +192,9 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { // Test whether the splits have been changed to the new Hadoop splits assert(operatedRDD.splits.toList === operatedRDD.checkpointData.cpRDDSplits.toList) + + // Test whether the number of splits is same as before + assert(operatedRDD.splits.length === numSplits) // Test whether the data in the checkpointed RDD is same as original assert(operatedRDD.collect() === result) @@ -168,7 +203,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { // does not have any dependency to another RDD (e.g., ParallelCollection, // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing. if (testRDDSize) { - println("Size of " + rddType + + logInfo("Size of " + rddType + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") assert( rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, @@ -184,7 +219,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { // must be forgotten after checkpointing (to remove all reference to parent RDDs) and // replaced with the HadoopSplits of the checkpointed RDD. if (testRDDSplitSize) { - println("Size of " + rddType + " splits " + logInfo("Size of " + rddType + " splits " + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]") assert( splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, @@ -294,14 +329,118 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val bytes = Utils.serialize(obj) Utils.deserialize[T](bytes) } + /* + test("Consistency check for ResultTask") { + // Time -----------------------> + // Core 1: |<- count in thread 1, task 1 ->| |<-- checkpoint, task 1 ---->| |<- count in thread 2, task 2 ->| + // Core 2: |<- count in thread 1, task 2 ->| |<--- checkpoint, task 2 ---------->| |<- count in thread 2, task 1 ->| + // | + // checkpoint completed + sc.stop(); sc = null + System.clearProperty("spark.master.port") + + val dir = File.createTempFile("temp_", "") + dir.delete() + val ctxt = new SparkContext("local[2]", "ResultTask") + ctxt.setCheckpointDir(dir.toString) + + try { + val rdd = ctxt.makeRDD(1 to 2, 2).map(x => { + val state = CheckpointSuite.incrementState() + println("State = " + state) + if (state <= 3) { + // If executing the two tasks for the job comouting rdd.count + // of thread 1, or the first task for the recomputation due + // to checkpointing (saveing to HDFS), then do nothing + } else if (state == 4) { + // If executing the second task for the recomputation due to + // checkpointing. then prolong this task, to allow rdd.count + // of thread 2 to start before checkpoint of this RDD is completed + + Thread.sleep(1000) + println("State = " + state + " wake up") + } else { + // Else executing the tasks from thread 2 + Thread.sleep(1000) + println("State = " + state + " wake up") + } + + (x, 1) + }) + rdd.checkpoint() + val env = SparkEnv.get + + val thread1 = new Thread() { + override def run() { + try { + SparkEnv.set(env) + rdd.count() + } catch { + case e: Exception => CheckpointSuite.failed("Exception in thread 1", e) + } + } + } + thread1.start() + + val thread2 = new Thread() { + override def run() { + try { + SparkEnv.set(env) + CheckpointSuite.waitTillState(3) + println("\n\n\n\n") + rdd.count() + } catch { + case e: Exception => CheckpointSuite.failed("Exception in thread 2", e) + } + } + } + thread2.start() + + thread1.join() + thread2.join() + } finally { + dir.delete() + } + + assert(!CheckpointSuite.failed, CheckpointSuite.failureMessage) + + ctxt.stop() + + } + */ } object CheckpointSuite { + /* + var state = 0 + var failed = false + var failureMessage = "" + + def incrementState(): Int = { + this.synchronized { state += 1; this.notifyAll(); state } + } + + def getState(): Int = { + this.synchronized( state ) + } + + def waitTillState(s: Int) { + while(state < s) { + this.synchronized { this.wait() } + } + } + + def failed(msg: String, ex: Exception) { + failed = true + failureMessage += msg + "\n" + ex + "\n\n" + } + */ + // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { - println("First = " + first + ", second = " + second) + //println("First = " + first + ", second = " + second) new CoGroupedRDD[K]( Seq(first.asInstanceOf[RDD[(_, _)]], second.asInstanceOf[RDD[(_, _)]]), part -- cgit v1.2.3 From 2a87d816a24c62215d682e3a7af65489c0d6e708 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Dec 2012 01:44:43 -0800 Subject: Added clear property to JavaAPISuite to remove port binding errors. --- core/src/test/scala/spark/JavaAPISuite.java | 2 ++ 1 file changed, 2 insertions(+) (limited to 'core') diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 5875506179..6bd9836a93 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -44,6 +44,8 @@ public class JavaAPISuite implements Serializable { public void tearDown() { sc.stop(); sc = null; + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port"); } static class ReverseIntComparator implements Comparator, Serializable { -- cgit v1.2.3 From fa28f25619d6712e5f920f498ec03085ea208b4d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Dec 2012 13:59:43 -0800 Subject: Fixed bug in UnionRDD and CoGroupedRDD --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 9 +- core/src/main/scala/spark/rdd/UnionRDD.scala | 12 +-- core/src/test/scala/spark/CheckpointSuite.scala | 104 ----------------------- 3 files changed, 10 insertions(+), 115 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index e4e70b13ba..bc6d16ee8b 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -15,8 +15,11 @@ import spark.Split import java.io.{ObjectOutputStream, IOException} private[spark] sealed trait CoGroupSplitDep extends Serializable -private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], splitIndex: Int, var split: Split = null) - extends CoGroupSplitDep { +private[spark] case class NarrowCoGroupSplitDep( + rdd: RDD[_], + splitIndex: Int, + var split: Split + ) extends CoGroupSplitDep { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { @@ -75,7 +78,7 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) case s: ShuffleDependency[_, _] => new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep case _ => - new NarrowCoGroupSplitDep(r, i): CoGroupSplitDep + new NarrowCoGroupSplitDep(r, i, r.splits(i)): CoGroupSplitDep } }.toList) } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 808729f18d..a84867492b 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -5,14 +5,10 @@ import scala.collection.mutable.ArrayBuffer import spark._ import java.io.{ObjectOutputStream, IOException} -private[spark] class UnionSplit[T: ClassManifest]( - idx: Int, - rdd: RDD[T], - splitIndex: Int, - var split: Split = null) - extends Split - with Serializable { - +private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int) + extends Split { + var split: Split = rdd.splits(splitIndex) + def iterator() = rdd.iterator(split) def preferredLocations() = rdd.preferredLocations(split) override val index: Int = idx diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 7b323e089c..909c55c91c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -329,114 +329,10 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val bytes = Utils.serialize(obj) Utils.deserialize[T](bytes) } - /* - test("Consistency check for ResultTask") { - // Time -----------------------> - // Core 1: |<- count in thread 1, task 1 ->| |<-- checkpoint, task 1 ---->| |<- count in thread 2, task 2 ->| - // Core 2: |<- count in thread 1, task 2 ->| |<--- checkpoint, task 2 ---------->| |<- count in thread 2, task 1 ->| - // | - // checkpoint completed - sc.stop(); sc = null - System.clearProperty("spark.master.port") - - val dir = File.createTempFile("temp_", "") - dir.delete() - val ctxt = new SparkContext("local[2]", "ResultTask") - ctxt.setCheckpointDir(dir.toString) - - try { - val rdd = ctxt.makeRDD(1 to 2, 2).map(x => { - val state = CheckpointSuite.incrementState() - println("State = " + state) - if (state <= 3) { - // If executing the two tasks for the job comouting rdd.count - // of thread 1, or the first task for the recomputation due - // to checkpointing (saveing to HDFS), then do nothing - } else if (state == 4) { - // If executing the second task for the recomputation due to - // checkpointing. then prolong this task, to allow rdd.count - // of thread 2 to start before checkpoint of this RDD is completed - - Thread.sleep(1000) - println("State = " + state + " wake up") - } else { - // Else executing the tasks from thread 2 - Thread.sleep(1000) - println("State = " + state + " wake up") - } - - (x, 1) - }) - rdd.checkpoint() - val env = SparkEnv.get - - val thread1 = new Thread() { - override def run() { - try { - SparkEnv.set(env) - rdd.count() - } catch { - case e: Exception => CheckpointSuite.failed("Exception in thread 1", e) - } - } - } - thread1.start() - - val thread2 = new Thread() { - override def run() { - try { - SparkEnv.set(env) - CheckpointSuite.waitTillState(3) - println("\n\n\n\n") - rdd.count() - } catch { - case e: Exception => CheckpointSuite.failed("Exception in thread 2", e) - } - } - } - thread2.start() - - thread1.join() - thread2.join() - } finally { - dir.delete() - } - - assert(!CheckpointSuite.failed, CheckpointSuite.failureMessage) - - ctxt.stop() - - } - */ } object CheckpointSuite { - /* - var state = 0 - var failed = false - var failureMessage = "" - - def incrementState(): Int = { - this.synchronized { state += 1; this.notifyAll(); state } - } - - def getState(): Int = { - this.synchronized( state ) - } - - def waitTillState(s: Int) { - while(state < s) { - this.synchronized { this.wait() } - } - } - - def failed(msg: String, ex: Exception) { - failed = true - failureMessage += msg + "\n" + ex + "\n\n" - } - */ - // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { -- cgit v1.2.3 From 8e74fac215e8b9cda7e35111c5116e3669c6eb97 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Dec 2012 15:36:12 -0800 Subject: Made checkpoint data in RDDs optional to further reduce serialized size. --- core/src/main/scala/spark/RDD.scala | 19 +++++++++++-------- core/src/main/scala/spark/SparkContext.scala | 11 +++++++++++ core/src/test/scala/spark/CheckpointSuite.scala | 12 ++++++------ .../src/main/scala/spark/streaming/DStream.scala | 4 +--- 4 files changed, 29 insertions(+), 17 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index efa03d5185..6c04769c82 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -112,7 +112,7 @@ abstract class RDD[T: ClassManifest]( // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE - protected[spark] val checkpointData = new RDDCheckpointData(this) + protected[spark] var checkpointData: Option[RDDCheckpointData[T]] = None /** Returns the first parent RDD */ protected[spark] def firstParent[U: ClassManifest] = { @@ -149,7 +149,7 @@ abstract class RDD[T: ClassManifest]( def getPreferredLocations(split: Split) = { if (isCheckpointed) { - checkpointData.preferredLocations(split) + checkpointData.get.preferredLocations(split) } else { preferredLocations(split) } @@ -163,7 +163,7 @@ abstract class RDD[T: ClassManifest]( final def iterator(split: Split): Iterator[T] = { if (isCheckpointed) { // ASSUMPTION: Checkpoint Hadoop RDD will have same number of splits as original - checkpointData.iterator(split) + checkpointData.get.iterator(split) } else if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) } else { @@ -516,21 +516,24 @@ abstract class RDD[T: ClassManifest]( * require recomputation. */ def checkpoint() { - checkpointData.markForCheckpoint() + if (checkpointData.isEmpty) { + checkpointData = Some(new RDDCheckpointData(this)) + checkpointData.get.markForCheckpoint() + } } /** * Return whether this RDD has been checkpointed or not */ def isCheckpointed(): Boolean = { - checkpointData.isCheckpointed() + if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false } /** * Gets the name of the file to which this RDD was checkpointed */ def getCheckpointFile(): Option[String] = { - checkpointData.getCheckpointFile() + if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None } /** @@ -539,12 +542,12 @@ abstract class RDD[T: ClassManifest]( * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. */ protected[spark] def doCheckpoint() { - checkpointData.doCheckpoint() + if (checkpointData.isDefined) checkpointData.get.doCheckpoint() dependencies.foreach(_.rdd.doCheckpoint()) } /** - * Changes the dependencies of this RDD from its original parents to the new [[spark.rdd.HadoopRDD]] + * Changes the dependencies of this RDD from its original parents to the new RDD * (`newRDD`) created from the checkpoint file. This method must ensure that all references * to the original parent RDDs must be removed to enable the parent RDDs to be garbage * collected. Subclasses of RDD may override this method for implementing their own changing diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 654b1c2eb7..71ed4ef058 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -366,6 +366,17 @@ class SparkContext( .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes)) } + + protected[spark] def checkpointFile[T: ClassManifest]( + path: String, + minSplits: Int = defaultMinSplits + ): RDD[T] = { + val rdd = objectFile[T](path, minSplits) + rdd.checkpointData = Some(new RDDCheckpointData(rdd)) + rdd.checkpointData.get.cpFile = Some(path) + rdd + } + /** Build the union of a list of RDDs. */ def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 909c55c91c..0bffedb8db 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -57,7 +57,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.splits.length === numSplits) - assert(parCollection.splits.toList === parCollection.checkpointData.cpRDDSplits.toList) + assert(parCollection.splits.toList === parCollection.checkpointData.get.cpRDDSplits.toList) assert(parCollection.collect() === result) } @@ -72,7 +72,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.splits.length === numSplits) - assert(blockRDD.splits.toList === blockRDD.checkpointData.cpRDDSplits.toList) + assert(blockRDD.splits.toList === blockRDD.checkpointData.get.cpRDDSplits.toList) assert(blockRDD.collect() === result) } @@ -84,7 +84,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { } test("UnionRDD") { - def otherRDD = sc.makeRDD(1 to 10, 4) + def otherRDD = sc.makeRDD(1 to 10, 1) // Test whether the size of UnionRDDSplits reduce in size after parent RDD is checkpointed. // Current implementation of UnionRDD has transient reference to parent RDDs, @@ -191,7 +191,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(operatedRDD.dependencies.head.rdd != parentRDD) // Test whether the splits have been changed to the new Hadoop splits - assert(operatedRDD.splits.toList === operatedRDD.checkpointData.cpRDDSplits.toList) + assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.cpRDDSplits.toList) // Test whether the number of splits is same as before assert(operatedRDD.splits.length === numSplits) @@ -289,8 +289,8 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { */ def generateLongLineageRDD(): RDD[Int] = { var rdd = sc.makeRDD(1 to 100, 4) - for (i <- 1 to 20) { - rdd = rdd.map(x => x) + for (i <- 1 to 50) { + rdd = rdd.map(x => x + 1) } rdd } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index d290c5927e..69fefa21a0 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -372,9 +372,7 @@ extends Serializable with Logging { checkpointData.foreach { case(time, data) => { logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'") - val rdd = ssc.sc.objectFile[T](data.toString) - // Set the checkpoint file name to identify this RDD as a checkpointed RDD by updateCheckpointData() - rdd.checkpointData.cpFile = Some(data.toString) + val rdd = ssc.sc.checkpointFile[T](data.toString) generatedRDDs += ((time, rdd)) } } -- cgit v1.2.3 From 72eed2b95edb3b0b213517c815e09c3886b11669 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 17 Dec 2012 18:52:43 -0800 Subject: Converted CheckpointState in RDDCheckpointData to use scala Enumeration. --- core/src/main/scala/spark/RDDCheckpointData.scala | 48 +++++++++++------------ 1 file changed, 22 insertions(+), 26 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index ff2ed4cdfc..7613b338e6 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -5,45 +5,41 @@ import rdd.CoalescedRDD import scheduler.{ResultTask, ShuffleMapTask} /** - * This class contains all the information of the regarding RDD checkpointing. + * Enumeration to manage state transitions of an RDD through checkpointing + * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] */ +private[spark] object CheckpointState extends Enumeration { + type CheckpointState = Value + val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value +} +/** + * This class contains all the information of the regarding RDD checkpointing. + */ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) extends Logging with Serializable { - /** - * This class manages the state transition of an RDD through checkpointing - * [ Not checkpointed --> marked for checkpointing --> checkpointing in progress --> checkpointed ] - */ - class CheckpointState extends Serializable { - var state = 0 + import CheckpointState._ - def mark() { if (state == 0) state = 1 } - def start() { assert(state == 1); state = 2 } - def finish() { assert(state == 2); state = 3 } - - def isMarked() = { state == 1 } - def isInProgress = { state == 2 } - def isCheckpointed = { state == 3 } - } - - val cpState = new CheckpointState() + var cpState = Initialized @transient var cpFile: Option[String] = None @transient var cpRDD: Option[RDD[T]] = None @transient var cpRDDSplits: Seq[Split] = Nil // Mark the RDD for checkpointing - def markForCheckpoint() = { - RDDCheckpointData.synchronized { cpState.mark() } + def markForCheckpoint() { + RDDCheckpointData.synchronized { + if (cpState == Initialized) cpState = MarkedForCheckpoint + } } // Is the RDD already checkpointed - def isCheckpointed() = { - RDDCheckpointData.synchronized { cpState.isCheckpointed } + def isCheckpointed(): Boolean = { + RDDCheckpointData.synchronized { cpState == Checkpointed } } - // Get the file to which this RDD was checkpointed to as a Option - def getCheckpointFile() = { + // Get the file to which this RDD was checkpointed to as an Option + def getCheckpointFile(): Option[String] = { RDDCheckpointData.synchronized { cpFile } } @@ -52,8 +48,8 @@ extends Logging with Serializable { // If it is marked for checkpointing AND checkpointing is not already in progress, // then set it to be in progress, else return RDDCheckpointData.synchronized { - if (cpState.isMarked && !cpState.isInProgress) { - cpState.start() + if (cpState == MarkedForCheckpoint) { + cpState = CheckpointingInProgress } else { return } @@ -87,7 +83,7 @@ extends Logging with Serializable { cpRDD = Some(newRDD) cpRDDSplits = newRDD.splits rdd.changeDependencies(newRDD) - cpState.finish() + cpState = Checkpointed RDDCheckpointData.checkpointCompleted() logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) } -- cgit v1.2.3 From 5184141936c18f12c6738caae6fceee4d15800e2 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 18 Dec 2012 13:30:53 -0800 Subject: Introduced getSpits, getDependencies, and getPreferredLocations in RDD and RDDCheckpointData. --- core/src/main/scala/spark/PairRDDFunctions.scala | 4 +- core/src/main/scala/spark/ParallelCollection.scala | 9 +- core/src/main/scala/spark/RDD.scala | 123 +++++++++++++-------- core/src/main/scala/spark/RDDCheckpointData.scala | 10 +- core/src/main/scala/spark/rdd/BlockRDD.scala | 9 +- core/src/main/scala/spark/rdd/CartesianRDD.scala | 12 +- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 11 +- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 10 +- core/src/main/scala/spark/rdd/FilteredRDD.scala | 2 +- core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 2 +- core/src/main/scala/spark/rdd/GlommedRDD.scala | 2 +- core/src/main/scala/spark/rdd/HadoopRDD.scala | 4 +- .../main/scala/spark/rdd/MapPartitionsRDD.scala | 2 +- .../spark/rdd/MapPartitionsWithSplitRDD.scala | 2 +- core/src/main/scala/spark/rdd/MappedRDD.scala | 2 +- core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 4 +- core/src/main/scala/spark/rdd/PipedRDD.scala | 2 +- core/src/main/scala/spark/rdd/SampledRDD.scala | 9 +- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 7 +- core/src/main/scala/spark/rdd/UnionRDD.scala | 13 +-- .../main/scala/spark/scheduler/DAGScheduler.scala | 2 +- core/src/test/scala/spark/CheckpointSuite.scala | 6 +- 22 files changed, 134 insertions(+), 113 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 1f82bd3ab8..09ac606cfb 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -628,7 +628,7 @@ private[spark] class MappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => U) extends RDD[(K, U)](prev.get) { - override def splits = firstParent[(K, V)].splits + override def getSplits = firstParent[(K, V)].splits override val partitioner = firstParent[(K, V)].partitioner override def compute(split: Split) = firstParent[(K, V)].iterator(split).map{case (k, v) => (k, f(v))} } @@ -637,7 +637,7 @@ private[spark] class FlatMappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U]) extends RDD[(K, U)](prev.get) { - override def splits = firstParent[(K, V)].splits + override def getSplits = firstParent[(K, V)].splits override val partitioner = firstParent[(K, V)].partitioner override def compute(split: Split) = { firstParent[(K, V)].iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) } diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9d12af6912..0bc5b2ff11 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -37,15 +37,12 @@ private[spark] class ParallelCollection[T: ClassManifest]( slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray } - override def splits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_.asInstanceOf[Array[Split]] override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator - - override def preferredLocations(s: Split): Seq[String] = Nil - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 6c04769c82..f3e422fa5f 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -81,48 +81,33 @@ abstract class RDD[T: ClassManifest]( def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) - // Methods that must be implemented by subclasses: - - /** Set of partitions in this RDD. */ - def splits: Array[Split] + // ======================================================================= + // Methods that should be implemented by subclasses of RDD + // ======================================================================= /** Function for computing a given partition. */ def compute(split: Split): Iterator[T] - /** How this RDD depends on any parent RDDs. */ - def dependencies: List[Dependency[_]] = dependencies_ + /** Set of partitions in this RDD. */ + protected def getSplits(): Array[Split] - /** Record user function generating this RDD. */ - private[spark] val origin = Utils.getSparkCallSite - - /** Optionally overridden by subclasses to specify how they are partitioned. */ - val partitioner: Option[Partitioner] = None + /** How this RDD depends on any parent RDDs. */ + protected def getDependencies(): List[Dependency[_]] = dependencies_ /** Optionally overridden by subclasses to specify placement preferences. */ - def preferredLocations(split: Split): Seq[String] = Nil - - /** The [[spark.SparkContext]] that this RDD was created on. */ - def context = sc + protected def getPreferredLocations(split: Split): Seq[String] = Nil - private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] - - /** A unique ID for this RDD (within its SparkContext). */ - val id = sc.newRddId() - - // Variables relating to persistence - private var storageLevel: StorageLevel = StorageLevel.NONE + /** Optionally overridden by subclasses to specify how they are partitioned. */ + val partitioner: Option[Partitioner] = None - protected[spark] var checkpointData: Option[RDDCheckpointData[T]] = None - /** Returns the first parent RDD */ - protected[spark] def firstParent[U: ClassManifest] = { - dependencies.head.rdd.asInstanceOf[RDD[U]] - } - /** Returns the `i` th parent RDD */ - protected[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] + // ======================================================================= + // Methods and fields available on all RDDs + // ======================================================================= - // Methods available on all RDDs: + /** A unique ID for this RDD (within its SparkContext). */ + val id = sc.newRddId() /** * Set this RDD's storage level to persist its values across operations after the first time @@ -147,11 +132,39 @@ abstract class RDD[T: ClassManifest]( /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - def getPreferredLocations(split: Split) = { + /** + * Get the preferred location of a split, taking into account whether the + * RDD is checkpointed or not. + */ + final def preferredLocations(split: Split): Seq[String] = { + if (isCheckpointed) { + checkpointData.get.getPreferredLocations(split) + } else { + getPreferredLocations(split) + } + } + + /** + * Get the array of splits of this RDD, taking into account whether the + * RDD is checkpointed or not. + */ + final def splits: Array[Split] = { + if (isCheckpointed) { + checkpointData.get.getSplits + } else { + getSplits + } + } + + /** + * Get the array of splits of this RDD, taking into account whether the + * RDD is checkpointed or not. + */ + final def dependencies: List[Dependency[_]] = { if (isCheckpointed) { - checkpointData.get.preferredLocations(split) + dependencies_ } else { - preferredLocations(split) + getDependencies } } @@ -536,6 +549,27 @@ abstract class RDD[T: ClassManifest]( if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None } + // ======================================================================= + // Other internal methods and fields + // ======================================================================= + + private var storageLevel: StorageLevel = StorageLevel.NONE + + /** Record user function generating this RDD. */ + private[spark] val origin = Utils.getSparkCallSite + + private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] + + private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None + + /** Returns the first parent RDD */ + protected[spark] def firstParent[U: ClassManifest] = { + dependencies.head.rdd.asInstanceOf[RDD[U]] + } + + /** The [[spark.SparkContext]] that this RDD was created on. */ + def context = sc + /** * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler * after a job using this RDD has completed (therefore the RDD has been materialized and @@ -548,23 +582,18 @@ abstract class RDD[T: ClassManifest]( /** * Changes the dependencies of this RDD from its original parents to the new RDD - * (`newRDD`) created from the checkpoint file. This method must ensure that all references - * to the original parent RDDs must be removed to enable the parent RDDs to be garbage - * collected. Subclasses of RDD may override this method for implementing their own changing - * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + * (`newRDD`) created from the checkpoint file. */ protected[spark] def changeDependencies(newRDD: RDD[_]) { + clearDependencies() dependencies_ = List(new OneToOneDependency(newRDD)) } - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - oos.defaultWriteObject() - } - - @throws(classOf[IOException]) - private def readObject(ois: ObjectInputStream) { - ois.defaultReadObject() - } - + /** + * Clears the dependencies of this RDD. This method must ensure that all references + * to the original parent RDDs must be removed to enable the parent RDDs to be garbage + * collected. Subclasses of RDD may override this method for implementing their own changing + * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + */ + protected[spark] def clearDependencies() { } } diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index 7613b338e6..e4c0912cdc 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -24,7 +24,6 @@ extends Logging with Serializable { var cpState = Initialized @transient var cpFile: Option[String] = None @transient var cpRDD: Option[RDD[T]] = None - @transient var cpRDDSplits: Seq[Split] = Nil // Mark the RDD for checkpointing def markForCheckpoint() { @@ -81,7 +80,6 @@ extends Logging with Serializable { RDDCheckpointData.synchronized { cpFile = Some(file) cpRDD = Some(newRDD) - cpRDDSplits = newRDD.splits rdd.changeDependencies(newRDD) cpState = Checkpointed RDDCheckpointData.checkpointCompleted() @@ -90,12 +88,18 @@ extends Logging with Serializable { } // Get preferred location of a split after checkpointing - def preferredLocations(split: Split) = { + def getPreferredLocations(split: Split) = { RDDCheckpointData.synchronized { cpRDD.get.preferredLocations(split) } } + def getSplits: Array[Split] = { + RDDCheckpointData.synchronized { + cpRDD.get.splits + } + } + // Get iterator. This is called at the worker nodes. def iterator(split: Split): Iterator[T] = { rdd.firstParent[T].iterator(split) diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index 0c8cdd10dd..68e570eb15 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -29,7 +29,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St HashMap(blockIds.zip(locations):_*) } - override def splits = splits_ + override def getSplits = splits_ override def compute(split: Split): Iterator[T] = { val blockManager = SparkEnv.get.blockManager @@ -41,12 +41,11 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St } } - override def preferredLocations(split: Split) = + override def getPreferredLocations(split: Split) = locations_(split.asInstanceOf[BlockRDDSplit].blockId) - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 9975e79b08..116644bd52 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -45,9 +45,9 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( array } - override def splits = splits_ + override def getSplits = splits_ - override def preferredLocations(split: Split) = { + override def getPreferredLocations(split: Split) = { val currSplit = split.asInstanceOf[CartesianSplit] rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) } @@ -66,11 +66,11 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( } ) - override def dependencies = deps_ + override def getDependencies = deps_ - override def changeDependencies(newRDD: RDD[_]) { - deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + deps_ = Nil + splits_ = null rdd1 = null rdd2 = null } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index bc6d16ee8b..9cc95dc172 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -65,9 +65,7 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) deps.toList } - // Pre-checkpoint dependencies deps_ should be transient (deps_) - // but post-checkpoint dependencies must not be transient (dependencies_) - override def dependencies = if (isCheckpointed) dependencies_ else deps_ + override def getDependencies = deps_ @transient var splits_ : Array[Split] = { @@ -85,7 +83,7 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) array } - override def splits = splits_ + override def getSplits = splits_ override val partitioner = Some(part) @@ -117,10 +115,9 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) map.iterator } - override def changeDependencies(newRDD: RDD[_]) { + override def clearDependencies() { deps_ = null - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + splits_ = null rdds = null } } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 088958942e..85d0fa9f6a 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -44,7 +44,7 @@ class CoalescedRDD[T: ClassManifest]( } } - override def splits = splits_ + override def getSplits = splits_ override def compute(split: Split): Iterator[T] = { split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { @@ -59,11 +59,11 @@ class CoalescedRDD[T: ClassManifest]( } ) - override def dependencies = deps_ + override def getDependencies() = deps_ - override def changeDependencies(newRDD: RDD[_]) { - deps_ = List(new OneToOneDependency(newRDD)) - splits_ = newRDD.splits + override def clearDependencies() { + deps_ = Nil + splits_ = null prev = null } } diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index 02f2e7c246..309ed2399d 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -9,6 +9,6 @@ class FilteredRDD[T: ClassManifest]( f: T => Boolean) extends RDD[T](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).filter(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index cdc8ecdcfe..1160e68bb8 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -9,6 +9,6 @@ class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( f: T => TraversableOnce[U]) extends RDD[U](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).flatMap(f) } diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index df6f61c69d..4fab1a56fa 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -6,6 +6,6 @@ import spark.Split private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = Array(firstParent[T].iterator(split).toArray).iterator } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index af54f23ebc..fce190b860 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -67,7 +67,7 @@ class HadoopRDD[K, V]( .asInstanceOf[InputFormat[K, V]] } - override def splits = splits_ + override def getSplits = splits_ override def compute(theSplit: Split) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[HadoopSplit] @@ -110,7 +110,7 @@ class HadoopRDD[K, V]( } } - override def preferredLocations(split: Split) = { + override def getPreferredLocations(split: Split) = { // TODO: Filtering out "localhost" in case of file:// URLs val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index 23b9fb023b..5f4acee041 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -12,6 +12,6 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = f(firstParent[T].iterator(split)) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index 41955c1d7a..f0f3f2c7c7 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -14,6 +14,6 @@ class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( f: (Int, Iterator[T]) => Iterator[U]) extends RDD[U](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = f(split.index, firstParent[T].iterator(split)) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 6f8cb21fd3..44b542db93 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -9,6 +9,6 @@ class MappedRDD[U: ClassManifest, T: ClassManifest]( f: T => U) extends RDD[U](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).map(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index c12df5839e..91f89e3c75 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -55,7 +55,7 @@ class NewHadoopRDD[K, V]( result } - override def splits = splits_ + override def getSplits = splits_ override def compute(theSplit: Split) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopSplit] @@ -89,7 +89,7 @@ class NewHadoopRDD[K, V]( } } - override def preferredLocations(split: Split) = { + override def getPreferredLocations(split: Split) = { val theSplit = split.asInstanceOf[NewHadoopSplit] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") } diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index d2047375ea..a88929e55e 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -29,7 +29,7 @@ class PipedRDD[T: ClassManifest]( // using a standard StringTokenizer (i.e. by spaces) def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split): Iterator[String] = { val pb = new ProcessBuilder(command) diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index c622e14a66..da6f65765c 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -26,9 +26,9 @@ class SampledRDD[T: ClassManifest]( firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } - override def splits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_.asInstanceOf[Array[Split]] - override def preferredLocations(split: Split) = + override def getPreferredLocations(split: Split) = firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) override def compute(splitIn: Split) = { @@ -51,8 +51,7 @@ class SampledRDD[T: ClassManifest]( } } - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index a9dd3f35ed..2caf33c21e 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -25,15 +25,14 @@ class ShuffledRDD[K, V]( @transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) - override def splits = splits_ + override def getSplits = splits_ override def compute(split: Split): Iterator[(K, V)] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) } - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index a84867492b..05ed6172d1 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -37,7 +37,7 @@ class UnionRDD[T: ClassManifest]( array } - override def splits = splits_ + override def getSplits = splits_ @transient var deps_ = { val deps = new ArrayBuffer[Dependency[_]] @@ -49,19 +49,16 @@ class UnionRDD[T: ClassManifest]( deps.toList } - // Pre-checkpoint dependencies deps_ should be transient (deps_) - // but post-checkpoint dependencies must not be transient (dependencies_) - override def dependencies = if (isCheckpointed) dependencies_ else deps_ + override def getDependencies = deps_ override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() - override def preferredLocations(s: Split): Seq[String] = + override def getPreferredLocations(s: Split): Seq[String] = s.asInstanceOf[UnionSplit[T]].preferredLocations() - override def changeDependencies(newRDD: RDD[_]) { + override def clearDependencies() { deps_ = null - dependencies_ = List(new OneToOneDependency(newRDD)) - splits_ = newRDD.splits + splits_ = null rdds = null } } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 33d35b35d1..4b2570fa2b 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -575,7 +575,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with return cached } // If the RDD has some placement preferences (as is the case for input RDDs), get those - val rddPrefs = rdd.getPreferredLocations(rdd.splits(partition)).toList + val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList if (rddPrefs != Nil) { return rddPrefs } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 0bffedb8db..19626d2450 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -57,7 +57,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.splits.length === numSplits) - assert(parCollection.splits.toList === parCollection.checkpointData.get.cpRDDSplits.toList) + assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList) assert(parCollection.collect() === result) } @@ -72,7 +72,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.splits.length === numSplits) - assert(blockRDD.splits.toList === blockRDD.checkpointData.get.cpRDDSplits.toList) + assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList) assert(blockRDD.collect() === result) } @@ -191,7 +191,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(operatedRDD.dependencies.head.rdd != parentRDD) // Test whether the splits have been changed to the new Hadoop splits - assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.cpRDDSplits.toList) + assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.getSplits.toList) // Test whether the number of splits is same as before assert(operatedRDD.splits.length === numSplits) -- cgit v1.2.3 From f9c5b0a6fe8d728e16c60c0cf51ced0054e3a387 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 20 Dec 2012 11:52:23 -0800 Subject: Changed checkpoint writing and reading process. --- core/src/main/scala/spark/RDDCheckpointData.scala | 27 +---- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 117 ++++++++++++++++++++++ core/src/main/scala/spark/rdd/HadoopRDD.scala | 5 +- 3 files changed, 124 insertions(+), 25 deletions(-) create mode 100644 core/src/main/scala/spark/rdd/CheckpointRDD.scala (limited to 'core') diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index e4c0912cdc..1aa9b9aa1e 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -1,7 +1,7 @@ package spark import org.apache.hadoop.fs.Path -import rdd.CoalescedRDD +import rdd.{CheckpointRDD, CoalescedRDD} import scheduler.{ResultTask, ShuffleMapTask} /** @@ -55,30 +55,13 @@ extends Logging with Serializable { } // Save to file, and reload it as an RDD - val file = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString - rdd.saveAsObjectFile(file) - - val newRDD = { - val hadoopRDD = rdd.context.objectFile[T](file, rdd.splits.size) - - val oldSplits = rdd.splits.size - val newSplits = hadoopRDD.splits.size - - logDebug("RDD splits = " + oldSplits + " --> " + newSplits) - if (newSplits < oldSplits) { - throw new Exception("# splits after checkpointing is less than before " + - "[" + oldSplits + " --> " + newSplits) - } else if (newSplits > oldSplits) { - new CoalescedRDD(hadoopRDD, rdd.splits.size) - } else { - hadoopRDD - } - } - logDebug("New RDD has " + newRDD.splits.size + " splits") + val path = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString + rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _) + val newRDD = new CheckpointRDD[T](rdd.context, path) // Change the dependencies and splits of the RDD RDDCheckpointData.synchronized { - cpFile = Some(file) + cpFile = Some(path) cpRDD = Some(newRDD) rdd.changeDependencies(newRDD) cpState = Checkpointed diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala new file mode 100644 index 0000000000..c673ab6aaa --- /dev/null +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -0,0 +1,117 @@ +package spark.rdd + +import spark._ +import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.io.{NullWritable, BytesWritable} +import org.apache.hadoop.util.ReflectionUtils +import org.apache.hadoop.fs.Path +import java.io.{File, IOException, EOFException} +import java.text.NumberFormat + +private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split { + override val index: Int = idx +} + +class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) + extends RDD[T](sc, Nil) { + + @transient val path = new Path(checkpointPath) + @transient val fs = path.getFileSystem(new Configuration()) + + @transient val splits_ : Array[Split] = { + val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted + splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray + } + + override def getSplits = splits_ + + override def getPreferredLocations(split: Split): Seq[String] = { + val status = fs.getFileStatus(path) + val locations = fs.getFileBlockLocations(status, 0, status.getLen) + locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost") + } + + override def compute(split: Split): Iterator[T] = { + CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile) + } + + override def checkpoint() { + // Do nothing. Hadoop RDD should not be checkpointed. + } +} + +private[spark] object CheckpointRDD extends Logging { + + def splitIdToFileName(splitId: Int): String = { + val numfmt = NumberFormat.getInstance() + numfmt.setMinimumIntegerDigits(5) + numfmt.setGroupingUsed(false) + "part-" + numfmt.format(splitId) + } + + def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) { + val outputDir = new Path(path) + val fs = outputDir.getFileSystem(new Configuration()) + + val finalOutputName = splitIdToFileName(context.splitId) + val finalOutputPath = new Path(outputDir, finalOutputName) + val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId) + + if (fs.exists(tempOutputPath)) { + throw new IOException("Checkpoint failed: temporary path " + + tempOutputPath + " already exists") + } + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + + val fileOutputStream = if (blockSize < 0) { + fs.create(tempOutputPath, false, bufferSize) + } else { + // This is mainly for testing purpose + fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) + } + val serializer = SparkEnv.get.serializer.newInstance() + val serializeStream = serializer.serializeStream(fileOutputStream) + serializeStream.writeAll(iterator) + fileOutputStream.close() + + if (!fs.rename(tempOutputPath, finalOutputPath)) { + if (!fs.delete(finalOutputPath, true)) { + throw new IOException("Checkpoint failed: failed to delete earlier output of task " + + context.attemptId); + } + if (!fs.rename(tempOutputPath, finalOutputPath)) { + throw new IOException("Checkpoint failed: failed to save output of task: " + + context.attemptId) + } + } + } + + def readFromFile[T](path: String): Iterator[T] = { + val inputPath = new Path(path) + val fs = inputPath.getFileSystem(new Configuration()) + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val fileInputStream = fs.open(inputPath, bufferSize) + val serializer = SparkEnv.get.serializer.newInstance() + val deserializeStream = serializer.deserializeStream(fileInputStream) + deserializeStream.asIterator.asInstanceOf[Iterator[T]] + } + + // Test whether CheckpointRDD generate expected number of splits despite + // each split file having multiple blocks. This needs to be run on a + // cluster (mesos or standalone) using HDFS. + def main(args: Array[String]) { + import spark._ + + val Array(cluster, hdfsPath) = args + val sc = new SparkContext(cluster, "CheckpointRDD Test") + val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) + val path = new Path(hdfsPath, "temp") + val fs = path.getFileSystem(new Configuration()) + sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 10) _) + val cpRDD = new CheckpointRDD[Int](sc, path.toString) + assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same") + assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same") + fs.delete(path) + } +} diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index fce190b860..eca51758e4 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -25,8 +25,7 @@ import spark.Split * A Spark split class that wraps around a Hadoop InputSplit. */ private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit) - extends Split - with Serializable { + extends Split { val inputSplit = new SerializableWritable[InputSplit](s) @@ -117,6 +116,6 @@ class HadoopRDD[K, V]( } override def checkpoint() { - // Do nothing. Hadoop RDD cannot be checkpointed. + // Do nothing. Hadoop RDD should not be checkpointed. } } -- cgit v1.2.3 From fe777eb77dee3c5bc5a7a332098d27f517ad3fe4 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 20 Dec 2012 13:39:27 -0800 Subject: Fixed bugs in CheckpointRDD and spark.CheckpointSuite. --- core/src/main/scala/spark/SparkContext.scala | 12 +++--------- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 3 +++ core/src/test/scala/spark/CheckpointSuite.scala | 6 +++--- 3 files changed, 9 insertions(+), 12 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 71ed4ef058..362aa04e66 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -37,9 +37,7 @@ import spark.broadcast._ import spark.deploy.LocalSparkCluster import spark.partial.ApproximateEvaluator import spark.partial.PartialResult -import spark.rdd.HadoopRDD -import spark.rdd.NewHadoopRDD -import spark.rdd.UnionRDD +import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD} import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} @@ -368,13 +366,9 @@ class SparkContext( protected[spark] def checkpointFile[T: ClassManifest]( - path: String, - minSplits: Int = defaultMinSplits + path: String ): RDD[T] = { - val rdd = objectFile[T](path, minSplits) - rdd.checkpointData = Some(new RDDCheckpointData(rdd)) - rdd.checkpointData.get.cpFile = Some(path) - rdd + new CheckpointRDD[T](this, path) } /** Build the union of a list of RDDs. */ diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index c673ab6aaa..fbf8a9ef83 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -24,6 +24,9 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray } + checkpointData = Some(new RDDCheckpointData[T](this)) + checkpointData.get.cpFile = Some(checkpointPath) + override def getSplits = splits_ override def getPreferredLocations(split: Split): Seq[String] = { diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 19626d2450..6bc667bd4c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -54,7 +54,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { parCollection.checkpoint() assert(parCollection.dependencies === Nil) val result = parCollection.collect() - assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) + assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.splits.length === numSplits) assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList) @@ -69,7 +69,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val numSplits = blockRDD.splits.size blockRDD.checkpoint() val result = blockRDD.collect() - assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) + assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.splits.length === numSplits) assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList) @@ -185,7 +185,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) // Test whether the checkpoint file has been created - assert(sc.objectFile[U](operatedRDD.getCheckpointFile.get).collect() === result) + assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result) // Test whether dependencies have been changed from its earlier parent RDD assert(operatedRDD.dependencies.head.rdd != parentRDD) -- cgit v1.2.3 From 4608902fb87af64a15b97ab21fe6382cd6e5a644 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 24 Dec 2012 17:20:10 -0800 Subject: Use filesystem to collect RDDs in PySpark. Passing large volumes of data through Py4J seems to be slow. It appears to be faster to write the data to the local filesystem and read it back from Python. --- .../main/scala/spark/api/python/PythonRDD.scala | 66 ++++++++-------------- pyspark/pyspark/context.py | 9 ++- pyspark/pyspark/rdd.py | 34 +++++++++-- pyspark/pyspark/serializers.py | 8 +++ pyspark/pyspark/worker.py | 12 +--- 5 files changed, 66 insertions(+), 63 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 50094d6b0f..4f870e837a 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -1,6 +1,7 @@ package spark.api.python import java.io._ +import java.util.{List => JList} import scala.collection.Map import scala.collection.JavaConversions._ @@ -59,36 +60,7 @@ trait PythonRDDBase { } out.flush() for (elem <- parent.iterator(split)) { - if (elem.isInstanceOf[Array[Byte]]) { - val arr = elem.asInstanceOf[Array[Byte]] - dOut.writeInt(arr.length) - dOut.write(arr) - } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) { - val t = elem.asInstanceOf[scala.Tuple2[_, _]] - val t1 = t._1.asInstanceOf[Array[Byte]] - val t2 = t._2.asInstanceOf[Array[Byte]] - val length = t1.length + t2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(PythonRDD.stripPickle(t1)) - dOut.write(PythonRDD.stripPickle(t2)) - dOut.writeByte(Pickle.TUPLE2) - dOut.writeByte(Pickle.STOP) - } else if (elem.isInstanceOf[String]) { - // For uniformity, strings are wrapped into Pickles. - val s = elem.asInstanceOf[String].getBytes("UTF-8") - val length = 2 + 1 + 4 + s.length + 1 - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.writeByte(Pickle.BINUNICODE) - dOut.writeInt(Integer.reverseBytes(s.length)) - dOut.write(s) - dOut.writeByte(Pickle.STOP) - } else { - throw new Exception("Unexpected RDD type") - } + PythonRDD.writeAsPickle(elem, dOut) } dOut.flush() out.flush() @@ -174,36 +146,45 @@ object PythonRDD { arr.slice(2, arr.length - 1) } - def asPickle(elem: Any) : Array[Byte] = { - val baos = new ByteArrayOutputStream(); - val dOut = new DataOutputStream(baos); + /** + * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. + * The data format is a 32-bit integer representing the pickled object's length (in bytes), + * followed by the pickled data. + * @param elem the object to write + * @param dOut a data output stream + */ + def writeAsPickle(elem: Any, dOut: DataOutputStream) { if (elem.isInstanceOf[Array[Byte]]) { - elem.asInstanceOf[Array[Byte]] + val arr = elem.asInstanceOf[Array[Byte]] + dOut.writeInt(arr.length) + dOut.write(arr) } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] + val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes + dOut.writeInt(length) dOut.writeByte(Pickle.PROTO) dOut.writeByte(Pickle.TWO) dOut.write(PythonRDD.stripPickle(t._1)) dOut.write(PythonRDD.stripPickle(t._2)) dOut.writeByte(Pickle.TUPLE2) dOut.writeByte(Pickle.STOP) - baos.toByteArray() } else if (elem.isInstanceOf[String]) { // For uniformity, strings are wrapped into Pickles. val s = elem.asInstanceOf[String].getBytes("UTF-8") + val length = 2 + 1 + 4 + s.length + 1 + dOut.writeInt(length) dOut.writeByte(Pickle.PROTO) dOut.writeByte(Pickle.TWO) dOut.write(Pickle.BINUNICODE) dOut.writeInt(Integer.reverseBytes(s.length)) dOut.write(s) dOut.writeByte(Pickle.STOP) - baos.toByteArray() } else { throw new Exception("Unexpected RDD type") } } - def pickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : + def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) val objs = new collection.mutable.ArrayBuffer[Array[Byte]] @@ -221,11 +202,12 @@ object PythonRDD { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } - def arrayAsPickle(arr : Any) : Array[Byte] = { - val pickles : Array[Byte] = arr.asInstanceOf[Array[Any]].map(asPickle).map(stripPickle).flatten - - Array[Byte](Pickle.PROTO, Pickle.TWO, Pickle.EMPTY_LIST, Pickle.MARK) ++ pickles ++ - Array[Byte] (Pickle.APPENDS, Pickle.STOP) + def writeArrayToPickleFile[T](items: Array[T], filename: String) { + val file = new DataOutputStream(new FileOutputStream(filename)) + for (item <- items) { + writeAsPickle(item, file) + } + file.close() } } diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 50d57e5317..19f9f9e133 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -14,9 +14,8 @@ class SparkContext(object): gateway = launch_gateway() jvm = gateway.jvm - pickleFile = jvm.spark.api.python.PythonRDD.pickleFile - asPickle = jvm.spark.api.python.PythonRDD.asPickle - arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle + readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile + writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile def __init__(self, master, name, defaultParallelism=None): self.master = master @@ -45,11 +44,11 @@ class SparkContext(object): # because it sends O(n) Py4J commands. As an alternative, serialized # objects are written to a file and loaded through textFile(). tempFile = NamedTemporaryFile(delete=False) + atexit.register(lambda: os.unlink(tempFile.name)) for x in c: write_with_length(dump_pickle(x), tempFile) tempFile.close() - atexit.register(lambda: os.unlink(tempFile.name)) - jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices) + jrdd = self.readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) return RDD(jrdd, self) def textFile(self, name, minSplits=None): diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 708ea6eb55..01908cff96 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,13 +1,15 @@ +import atexit from base64 import standard_b64encode as b64enc from collections import defaultdict from itertools import chain, ifilter, imap import os import shlex from subprocess import Popen, PIPE +from tempfile import NamedTemporaryFile from threading import Thread from pyspark import cloudpickle -from pyspark.serializers import dump_pickle, load_pickle +from pyspark.serializers import dump_pickle, load_pickle, read_from_pickle_file from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup @@ -145,10 +147,30 @@ class RDD(object): self.map(f).collect() # Force evaluation def collect(self): + # To minimize the number of transfers between Python and Java, we'll + # flatten each partition into a list before collecting it. Due to + # pipelining, this should add minimal overhead. def asList(iterator): yield list(iterator) - pickles = self.mapPartitions(asList)._jrdd.rdd().collect() - return list(chain.from_iterable(load_pickle(bytes(p)) for p in pickles)) + picklesInJava = self.mapPartitions(asList)._jrdd.rdd().collect() + return list(chain.from_iterable(self._collect_array_through_file(picklesInJava))) + + def _collect_array_through_file(self, array): + # Transferring lots of data through Py4J can be slow because + # socket.readline() is inefficient. Instead, we'll dump the data to a + # file and read it back. + tempFile = NamedTemporaryFile(delete=False) + tempFile.close() + def clean_up_file(): + try: os.unlink(tempFile.name) + except: pass + atexit.register(clean_up_file) + self.ctx.writeArrayToPickleFile(array, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + for item in read_from_pickle_file(tempFile): + yield item + os.unlink(tempFile.name) def reduce(self, f): """ @@ -220,15 +242,15 @@ class RDD(object): >>> sc.parallelize([2, 3, 4]).take(2) [2, 3] """ - pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num)) - return load_pickle(bytes(pickle)) + picklesInJava = self._jrdd.rdd().take(num) + return list(self._collect_array_through_file(picklesInJava)) def first(self): """ >>> sc.parallelize([2, 3, 4]).first() 2 """ - return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first()))) + return self.take(1)[0] def saveAsTextFile(self, path): def func(iterator): diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index 21ef8b106c..bfcdda8f12 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -33,3 +33,11 @@ def read_with_length(stream): if obj == "": raise EOFError return obj + + +def read_from_pickle_file(stream): + try: + while True: + yield load_pickle(read_with_length(stream)) + except EOFError: + return diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 62824a1c9b..9f6b507dbd 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -8,7 +8,7 @@ from base64 import standard_b64decode from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.serializers import write_with_length, read_with_length, \ - read_long, read_int, dump_pickle, load_pickle + read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file # Redirect stdout to stderr so that users must return values from functions. @@ -20,14 +20,6 @@ def load_obj(): return load_pickle(standard_b64decode(sys.stdin.readline().strip())) -def read_input(): - try: - while True: - yield load_pickle(read_with_length(sys.stdin)) - except EOFError: - return - - def main(): num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): @@ -40,7 +32,7 @@ def main(): dumps = lambda x: x else: dumps = dump_pickle - for obj in func(read_input()): + for obj in func(read_from_pickle_file(sys.stdin)): write_with_length(dumps(obj), old_stdout) -- cgit v1.2.3 From 1dca0c51804b9c94709ec9cc0544b8dfb7afe59f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 26 Dec 2012 18:23:06 -0800 Subject: Remove debug output from PythonPartitioner. --- core/src/main/scala/spark/api/python/PythonPartitioner.scala | 2 -- 1 file changed, 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala index ef9f808fb2..606a80d1eb 100644 --- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -16,8 +16,6 @@ class PythonPartitioner(override val numPartitions: Int) extends Partitioner { else { val hashCode = { if (key.isInstanceOf[Array[Byte]]) { - System.err.println("Dumping a byte array!" + Arrays.hashCode(key.asInstanceOf[Array[Byte]]) - ) Arrays.hashCode(key.asInstanceOf[Array[Byte]]) } else -- cgit v1.2.3 From 0bc0a60d3001dd231e13057a838d4b6550e5a2b9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 27 Dec 2012 15:37:33 -0800 Subject: Modifications to make sure LocalScheduler terminate cleanly without errors when SparkContext is shutdown, to minimize spurious exception during master failure tests. --- core/src/main/scala/spark/SparkContext.scala | 22 ++++++++++++---------- .../spark/scheduler/local/LocalScheduler.scala | 8 ++++++-- core/src/test/resources/log4j.properties | 2 +- .../src/test/scala/spark/ClosureCleanerSuite.scala | 2 ++ streaming/src/test/resources/log4j.properties | 13 ++++++++----- 5 files changed, 29 insertions(+), 18 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index caa9a1794b..0c8b0078a3 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -488,17 +488,19 @@ class SparkContext( if (dagScheduler != null) { dagScheduler.stop() dagScheduler = null + taskScheduler = null + // TODO: Cache.stop()? + env.stop() + // Clean up locally linked files + clearFiles() + clearJars() + SparkEnv.set(null) + ShuffleMapTask.clearCache() + ResultTask.clearCache() + logInfo("Successfully stopped SparkContext") + } else { + logInfo("SparkContext already stopped") } - taskScheduler = null - // TODO: Cache.stop()? - env.stop() - // Clean up locally linked files - clearFiles() - clearJars() - SparkEnv.set(null) - ShuffleMapTask.clearCache() - ResultTask.clearCache() - logInfo("Successfully stopped SparkContext") } /** diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index eb20fe41b2..17a0a4b103 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -81,7 +81,10 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( ser.serialize(Accumulators.values)) logInfo("Finished task " + idInJob) - listener.taskEnded(task, Success, resultToReturn, accumUpdates) + + // If the threadpool has not already been shutdown, notify DAGScheduler + if (!Thread.currentThread().isInterrupted) + listener.taskEnded(task, Success, resultToReturn, accumUpdates) } catch { case t: Throwable => { logError("Exception in task " + idInJob, t) @@ -91,7 +94,8 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon submitTask(task, idInJob) } else { // TODO: Do something nicer here to return all the way to the user - listener.taskEnded(task, new ExceptionFailure(t), null, null) + if (!Thread.currentThread().isInterrupted) + listener.taskEnded(task, new ExceptionFailure(t), null, null) } } } diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index 4c99e450bc..5ed388e91b 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -1,4 +1,4 @@ -# Set everything to be logged to the console +# Set everything to be logged to the file spark-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false diff --git a/core/src/test/scala/spark/ClosureCleanerSuite.scala b/core/src/test/scala/spark/ClosureCleanerSuite.scala index 7c0334d957..dfa2de80e6 100644 --- a/core/src/test/scala/spark/ClosureCleanerSuite.scala +++ b/core/src/test/scala/spark/ClosureCleanerSuite.scala @@ -47,6 +47,8 @@ object TestObject { val nums = sc.parallelize(Array(1, 2, 3, 4)) val answer = nums.map(_ + x).reduce(_ + _) sc.stop() + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") return answer } } diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 02fe16866e..33bafebaab 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -1,8 +1,11 @@ -# Set everything to be logged to the console -log4j.rootCategory=WARN, console -log4j.appender.console=org.apache.log4j.ConsoleAppender -log4j.appender.console.layout=org.apache.log4j.PatternLayout -log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +# Set everything to be logged to the file streaming-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=streaming-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN + -- cgit v1.2.3 From fbadb1cda504b256e3d12c4ce389e723b6f2503c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 28 Dec 2012 09:06:11 -0800 Subject: Mark api.python classes as private; echo Java output to stderr. --- .../scala/spark/api/python/PythonPartitioner.scala | 2 +- .../main/scala/spark/api/python/PythonRDD.scala | 50 +++++++++------------- pyspark/pyspark/java_gateway.py | 3 +- 3 files changed, 24 insertions(+), 31 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala index 606a80d1eb..2c829508e5 100644 --- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -7,7 +7,7 @@ import java.util.Arrays /** * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API. */ -class PythonPartitioner(override val numPartitions: Int) extends Partitioner { +private[spark] class PythonPartitioner(override val numPartitions: Int) extends Partitioner { override def getPartition(key: Any): Int = { if (key == null) { diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 4f870e837a..a80a8eea45 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -3,7 +3,6 @@ package spark.api.python import java.io._ import java.util.{List => JList} -import scala.collection.Map import scala.collection.JavaConversions._ import scala.io.Source @@ -16,10 +15,26 @@ import spark.OneToOneDependency import spark.rdd.PipedRDD -trait PythonRDDBase { - def compute[T](split: Split, envVars: Map[String, String], - command: Seq[String], parent: RDD[T], pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]): Iterator[Array[Byte]] = { +private[spark] class PythonRDD[T: ClassManifest]( + parent: RDD[T], command: Seq[String], envVars: java.util.Map[String, String], + preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) + extends RDD[Array[Byte]](parent.context) { + + // Similar to Runtime.exec(), if we are given a single string, split it into words + // using a standard StringTokenizer (i.e. by spaces) + def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], + preservePartitoning: Boolean, pythonExec: String, + broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec, + broadcastVars) + + override def splits = parent.splits + + override val dependencies = List(new OneToOneDependency(parent)) + + override val partitioner = if (preservePartitoning) parent.partitioner else None + + override def compute(split: Split): Iterator[Array[Byte]] = { val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) @@ -100,29 +115,6 @@ trait PythonRDDBase { def hasNext = _nextObj.length != 0 } } -} - -class PythonRDD[T: ClassManifest]( - parent: RDD[T], command: Seq[String], envVars: java.util.Map[String, String], - preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) - extends RDD[Array[Byte]](parent.context) with PythonRDDBase { - - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec, - broadcastVars) - - override def splits = parent.splits - - override val dependencies = List(new OneToOneDependency(parent)) - - override val partitioner = if (preservePartitoning) parent.partitioner else None - - override def compute(split: Split): Iterator[Array[Byte]] = - compute(split, envVars.toMap, command, parent, pythonExec, broadcastVars) val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } @@ -139,7 +131,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } -object PythonRDD { +private[spark] object PythonRDD { /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ def stripPickle(arr: Array[Byte]) : Array[Byte] = { diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py index d4a4434c05..eb2a875762 100644 --- a/pyspark/pyspark/java_gateway.py +++ b/pyspark/pyspark/java_gateway.py @@ -1,4 +1,5 @@ import os +import sys from subprocess import Popen, PIPE from threading import Thread from py4j.java_gateway import java_import, JavaGateway, GatewayClient @@ -26,7 +27,7 @@ def launch_gateway(): def run(self): while True: line = self.stream.readline() - print line, + sys.stderr.write(line) EchoOutputThread(proc.stdout).start() # Connect to the gateway gateway = JavaGateway(GatewayClient(port=port)) -- cgit v1.2.3 From f1bf4f0385a8e5da14a1d4b01bbbea17b98c4aa3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 28 Dec 2012 16:13:23 -0800 Subject: Skip deletion of files in clearFiles(). This fixes an issue where Spark could delete original files in the current working directory that were added to the job using addFile(). There was also the potential for addFile() to overwrite local files, which is addressed by changing Utils.fetchFile() to log a warning instead of overwriting a file with new contents. This is a short-term fix; a better long-term solution would be to remove the dependence on storing files in the current working directory, since we can't change the cwd from Java. --- core/src/main/scala/spark/SparkContext.scala | 9 ++--- core/src/main/scala/spark/Utils.scala | 57 ++++++++++++++++++++-------- 2 files changed, 46 insertions(+), 20 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 0afab522af..4fd81bc63b 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -419,8 +419,9 @@ class SparkContext( } addedFiles(key) = System.currentTimeMillis - // Fetch the file locally in case the task is executed locally - val filename = new File(path.split("/").last) + // Fetch the file locally in case a job is executed locally. + // Jobs that run through LocalScheduler will already fetch the required dependencies, + // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here. Utils.fetchFile(path, new File(".")) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) @@ -437,11 +438,10 @@ class SparkContext( } /** - * Clear the job's list of files added by `addFile` so that they do not get donwloaded to + * Clear the job's list of files added by `addFile` so that they do not get downloaded to * any new nodes. */ def clearFiles() { - addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() } addedFiles.clear() } @@ -465,7 +465,6 @@ class SparkContext( * any new nodes. */ def clearJars() { - addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() } addedJars.clear() } diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 6d64b32174..c10b415a93 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -9,6 +9,7 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.io.Source +import com.google.common.io.Files /** * Various utility methods used by Spark. @@ -130,28 +131,47 @@ private object Utils extends Logging { */ def fetchFile(url: String, targetDir: File) { val filename = url.split("/").last + val tempDir = System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")) + val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) val targetFile = new File(targetDir, filename) val uri = new URI(url) uri.getScheme match { case "http" | "https" | "ftp" => - logInfo("Fetching " + url + " to " + targetFile) + logInfo("Fetching " + url + " to " + tempFile) val in = new URL(url).openStream() - val out = new FileOutputStream(targetFile) + val out = new FileOutputStream(tempFile) Utils.copyStream(in, out, true) + if (targetFile.exists && !Files.equal(tempFile, targetFile)) { + logWarning("File " + targetFile + " exists and does not match contents of " + url + + "; using existing version") + tempFile.delete() + } else { + Files.move(tempFile, targetFile) + } case "file" | null => - // Remove the file if it already exists - targetFile.delete() - // Symlink the file locally. - if (uri.isAbsolute) { - // url is absolute, i.e. it starts with "file:///". Extract the source - // file's absolute path from the url. - val sourceFile = new File(uri) - logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) - FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath) + val sourceFile = if (uri.isAbsolute) { + new File(uri) + } else { + new File(url) + } + if (targetFile.exists && !Files.equal(sourceFile, targetFile)) { + logWarning("File " + targetFile + " exists and does not match contents of " + url + + "; using existing version") } else { - // url is not absolute, i.e. itself is the path to the source file. - logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath) - FileUtil.symLink(url, targetFile.getAbsolutePath) + // Remove the file if it already exists + targetFile.delete() + // Symlink the file locally. + if (uri.isAbsolute) { + // url is absolute, i.e. it starts with "file:///". Extract the source + // file's absolute path from the url. + val sourceFile = new File(uri) + logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) + FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath) + } else { + // url is not absolute, i.e. itself is the path to the source file. + logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath) + FileUtil.symLink(url, targetFile.getAbsolutePath) + } } case _ => // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others @@ -159,8 +179,15 @@ private object Utils extends Logging { val conf = new Configuration() val fs = FileSystem.get(uri, conf) val in = fs.open(new Path(uri)) - val out = new FileOutputStream(targetFile) + val out = new FileOutputStream(tempFile) Utils.copyStream(in, out, true) + if (targetFile.exists && !Files.equal(tempFile, targetFile)) { + logWarning("File " + targetFile + " exists and does not match contents of " + url + + "; using existing version") + tempFile.delete() + } else { + Files.move(tempFile, targetFile) + } } // Decompress the file if it's a .tar or .tar.gz if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) { -- cgit v1.2.3 From bd237d4a9d7f08eb143b2a2b8636a6a8453225ea Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 28 Dec 2012 16:14:36 -0800 Subject: Add synchronization to LocalScheduler.updateDependencies(). --- .../spark/scheduler/local/LocalScheduler.scala | 34 ++++++++++++---------- 1 file changed, 18 insertions(+), 16 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index eb20fe41b2..5d927efb65 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -108,22 +108,24 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon * SparkContext. Also adds any new JARs we fetched to the class loader. */ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) - currentFiles(name) = timestamp - } - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(".", localName).toURI.toURL - if (!classLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - classLoader.addURL(url) + this.synchronized { + // Fetch missing dependencies + for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(".")) + currentFiles(name) = timestamp + } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(".")) + currentJars(name) = timestamp + // Add it to our class loader + val localName = name.split("/").last + val url = new File(".", localName).toURI.toURL + if (!classLoader.getURLs.contains(url)) { + logInfo("Adding " + url + " to class loader") + classLoader.addURL(url) + } } } } -- cgit v1.2.3 From d64fa72d2e4a8290d15e65459337f544e55b3b48 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 28 Dec 2012 16:20:38 -0800 Subject: Add addFile() and addJar() to JavaSparkContext. --- .../scala/spark/api/java/JavaSparkContext.scala | 34 ++++++++++++++++++++++ 1 file changed, 34 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index edbb187b1b..b7725313c4 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -301,6 +301,40 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork * (in that order of preference). If neither of these is set, return None. */ def getSparkHome(): Option[String] = sc.getSparkHome() + + /** + * Add a file to be downloaded into the working directory of this Spark job on every node. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. + */ + def addFile(path: String) { + sc.addFile(path) + } + + /** + * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. + */ + def addJar(path: String) { + sc.addJar(path) + } + + /** + * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to + * any new nodes. + */ + def clearJars() { + sc.clearJars() + } + + /** + * Clear the job's list of files added by `addFile` so that they do not get downloaded to + * any new nodes. + */ + def clearFiles() { + sc.clearFiles() + } } object JavaSparkContext { -- cgit v1.2.3 From 397e67103c18ba22c8c63e9692f0096cd0094797 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 28 Dec 2012 17:37:13 -0800 Subject: Change Utils.fetchFile() warning to SparkException. --- core/src/main/scala/spark/Utils.scala | 15 +++++++++------ .../main/scala/spark/scheduler/local/LocalScheduler.scala | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index c10b415a93..0e7007459d 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -128,6 +128,9 @@ private object Utils extends Logging { /** * Download a file requested by the executor. Supports fetching the file in a variety of ways, * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. + * + * Throws SparkException if the target file already exists and has different contents than + * the requested file. */ def fetchFile(url: String, targetDir: File) { val filename = url.split("/").last @@ -142,9 +145,9 @@ private object Utils extends Logging { val out = new FileOutputStream(tempFile) Utils.copyStream(in, out, true) if (targetFile.exists && !Files.equal(tempFile, targetFile)) { - logWarning("File " + targetFile + " exists and does not match contents of " + url + - "; using existing version") tempFile.delete() + throw new SparkException("File " + targetFile + " exists and does not match contents of" + + " " + url) } else { Files.move(tempFile, targetFile) } @@ -155,8 +158,8 @@ private object Utils extends Logging { new File(url) } if (targetFile.exists && !Files.equal(sourceFile, targetFile)) { - logWarning("File " + targetFile + " exists and does not match contents of " + url + - "; using existing version") + throw new SparkException("File " + targetFile + " exists and does not match contents of" + + " " + url) } else { // Remove the file if it already exists targetFile.delete() @@ -182,9 +185,9 @@ private object Utils extends Logging { val out = new FileOutputStream(tempFile) Utils.copyStream(in, out, true) if (targetFile.exists && !Files.equal(tempFile, targetFile)) { - logWarning("File " + targetFile + " exists and does not match contents of " + url + - "; using existing version") tempFile.delete() + throw new SparkException("File " + targetFile + " exists and does not match contents of" + + " " + url) } else { Files.move(tempFile, targetFile) } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 5d927efb65..2593c0e3a0 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -108,7 +108,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon * SparkContext. Also adds any new JARs we fetched to the class loader. */ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - this.synchronized { + synchronized { // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) -- cgit v1.2.3 From 7ec3595de28d53839cb3a45e940ec16f81ffdf45 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 28 Dec 2012 22:19:12 -0800 Subject: Fix bug (introduced by batching) in PySpark take() --- .../main/scala/spark/api/python/PythonRDD.scala | 2 +- pyspark/pyspark/context.py | 6 ++--- pyspark/pyspark/java_gateway.py | 2 +- pyspark/pyspark/rdd.py | 27 ++++++++++++++-------- 4 files changed, 22 insertions(+), 15 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index a80a8eea45..f76616a4c4 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -194,7 +194,7 @@ private[spark] object PythonRDD { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } - def writeArrayToPickleFile[T](items: Array[T], filename: String) { + def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { val file = new DataOutputStream(new FileOutputStream(filename)) for (item <- items) { writeAsPickle(item, file) diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 988c81cd5d..b90596ecc2 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -19,8 +19,8 @@ class SparkContext(object): gateway = launch_gateway() jvm = gateway.jvm - readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile - writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile + _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile + _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): @@ -94,7 +94,7 @@ class SparkContext(object): for x in c: write_with_length(dump_pickle(x), tempFile) tempFile.close() - jrdd = self.readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) + jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) return RDD(jrdd, self) def textFile(self, name, minSplits=None): diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py index eb2a875762..2329e536cc 100644 --- a/pyspark/pyspark/java_gateway.py +++ b/pyspark/pyspark/java_gateway.py @@ -30,7 +30,7 @@ def launch_gateway(): sys.stderr.write(line) EchoOutputThread(proc.stdout).start() # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=port)) + gateway = JavaGateway(GatewayClient(port=port), auto_convert=False) # Import the classes used by PySpark java_import(gateway.jvm, "spark.api.java.*") java_import(gateway.jvm, "spark.api.python.*") diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index bf32472d25..111476d274 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -152,8 +152,8 @@ class RDD(object): into a list. >>> rdd = sc.parallelize([1, 2, 3, 4], 2) - >>> rdd.glom().first() - [1, 2] + >>> sorted(rdd.glom().collect()) + [[1, 2], [3, 4]] """ def func(iterator): yield list(iterator) return self.mapPartitions(func) @@ -211,10 +211,10 @@ class RDD(object): """ Return a list that contains all of the elements in this RDD. """ - picklesInJava = self._jrdd.rdd().collect() - return list(self._collect_array_through_file(picklesInJava)) + picklesInJava = self._jrdd.collect().iterator() + return list(self._collect_iterator_through_file(picklesInJava)) - def _collect_array_through_file(self, array): + def _collect_iterator_through_file(self, iterator): # Transferring lots of data through Py4J can be slow because # socket.readline() is inefficient. Instead, we'll dump the data to a # file and read it back. @@ -224,7 +224,7 @@ class RDD(object): try: os.unlink(tempFile.name) except: pass atexit.register(clean_up_file) - self.ctx.writeArrayToPickleFile(array, tempFile.name) + self.ctx._writeIteratorToPickleFile(iterator, tempFile.name) # Read the data into Python and deserialize it: with open(tempFile.name, 'rb') as tempFile: for item in read_from_pickle_file(tempFile): @@ -325,11 +325,18 @@ class RDD(object): a lot of partitions are required. In that case, use L{collect} to get the whole RDD instead. - >>> sc.parallelize([2, 3, 4]).take(2) + >>> sc.parallelize([2, 3, 4, 5, 6]).take(2) [2, 3] - """ - picklesInJava = self._jrdd.rdd().take(num) - return list(self._collect_array_through_file(picklesInJava)) + >>> sc.parallelize([2, 3, 4, 5, 6]).take(10) + [2, 3, 4, 5, 6] + """ + items = [] + splits = self._jrdd.splits() + while len(items) < num and splits: + split = splits.pop(0) + iterator = self._jrdd.iterator(split) + items.extend(self._collect_iterator_through_file(iterator)) + return items[:num] def first(self): """ -- cgit v1.2.3 From 59195c68ec37acf20d527189ed757397b273a207 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 29 Dec 2012 16:01:03 -0800 Subject: Update PySpark for compatibility with TaskContext. --- core/src/main/scala/spark/api/python/PythonRDD.scala | 13 +++++-------- pyspark/pyspark/rdd.py | 3 ++- 2 files changed, 7 insertions(+), 9 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index f76616a4c4..dc48378fdc 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -8,10 +8,7 @@ import scala.io.Source import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import spark.broadcast.Broadcast -import spark.SparkEnv -import spark.Split -import spark.RDD -import spark.OneToOneDependency +import spark._ import spark.rdd.PipedRDD @@ -34,7 +31,7 @@ private[spark] class PythonRDD[T: ClassManifest]( override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Split): Iterator[Array[Byte]] = { + override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = { val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) @@ -74,7 +71,7 @@ private[spark] class PythonRDD[T: ClassManifest]( out.println(elem) } out.flush() - for (elem <- parent.iterator(split)) { + for (elem <- parent.iterator(split, context)) { PythonRDD.writeAsPickle(elem, dOut) } dOut.flush() @@ -123,8 +120,8 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Array[Byte], Array[Byte])](prev.context) { override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = - prev.iterator(split).grouped(2).map { + override def compute(split: Split, context: TaskContext) = + prev.iterator(split, context).grouped(2).map { case Seq(a, b) => (a, b) case x => throw new Exception("PairwiseRDD: unexpected value: " + x) } diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 203f7377d2..21dda31c4e 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -335,9 +335,10 @@ class RDD(object): """ items = [] splits = self._jrdd.splits() + taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0) while len(items) < num and splits: split = splits.pop(0) - iterator = self._jrdd.iterator(split) + iterator = self._jrdd.iterator(split, taskContext) items.extend(self._collect_iterator_through_file(iterator)) return items[:num] -- cgit v1.2.3 From 9e644402c155b5fc68794a17c36ddd19d3242f4f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 29 Dec 2012 18:31:51 -0800 Subject: Improved jekyll and scala docs. Made many classes and method private to remove them from scala docs. --- core/src/main/scala/spark/RDD.scala | 1 - docs/_plugins/copy_api_dirs.rb | 4 +- docs/streaming-programming-guide.md | 56 ++--- .../main/scala/spark/streaming/Checkpoint.scala | 5 +- .../src/main/scala/spark/streaming/DStream.scala | 249 +++++++++++++-------- .../scala/spark/streaming/FlumeInputDStream.scala | 2 +- .../src/main/scala/spark/streaming/Interval.scala | 1 + streaming/src/main/scala/spark/streaming/Job.scala | 2 + .../main/scala/spark/streaming/JobManager.scala | 1 + .../spark/streaming/NetworkInputDStream.scala | 8 +- .../spark/streaming/NetworkInputTracker.scala | 8 +- .../spark/streaming/PairDStreamFunctions.scala | 4 +- .../src/main/scala/spark/streaming/Scheduler.scala | 7 +- .../scala/spark/streaming/StreamingContext.scala | 43 ++-- .../scala/spark/streaming/examples/GrepRaw.scala | 2 +- .../streaming/examples/TopKWordCountRaw.scala | 2 +- .../spark/streaming/examples/WordCountRaw.scala | 2 +- .../examples/clickstream/PageViewStream.scala | 2 +- .../test/scala/spark/streaming/TestSuiteBase.scala | 2 +- 19 files changed, 233 insertions(+), 168 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 59e50a0b6b..1574533430 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -101,7 +101,6 @@ abstract class RDD[T: ClassManifest]( val partitioner: Option[Partitioner] = None - // ======================================================================= // Methods and fields available on all RDDs // ======================================================================= diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index e61c105449..7654511eeb 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -2,7 +2,7 @@ require 'fileutils' include FileUtils if ENV['SKIP_SCALADOC'] != '1' - projects = ["core", "examples", "repl", "bagel"] + projects = ["core", "examples", "repl", "bagel", "streaming"] puts "Moving to project root and building scaladoc." curr_dir = pwd @@ -11,7 +11,7 @@ if ENV['SKIP_SCALADOC'] != '1' puts "Running sbt/sbt doc from " + pwd + "; this may take a few minutes..." puts `sbt/sbt doc` - puts "moving back into docs dir." + puts "Moving back into docs dir." cd("docs") # Copy over the scaladoc from each project into the docs directory. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 90916545bc..7c421ac70f 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2,33 +2,44 @@ layout: global title: Streaming (Alpha) Programming Guide --- + +{:toc} + +# Overview +A Spark Streaming application is very similar to a Spark application; it consists of a *driver program* that runs the user's `main` function and continuous executes various *parallel operations* on input streams of data. The main abstraction Spark Streaming provides is a *discretized stream* (DStream), which is a continuous sequence of RDDs (distributed collection of elements) representing a continuous stream of data. DStreams can created from live incoming data (such as data from a socket, Kafka, etc.) or it can be generated by transformation of existing DStreams using parallel operators like map, reduce, and window. The basic processing model is as follows: +(i) While a Spark Streaming driver program is running, the system receives data from various sources and and divides the data into batches. Each batch of data is treated as a RDD, that is a immutable and parallel collection of data. These input data RDDs are automatically persisted in memory (serialized by default) and replicated to two nodes for fault-tolerance. This sequence of RDDs is collectively referred to as an InputDStream. +(ii) Data received by InputDStreams are processed processed using DStream operations. Since all data is represented as RDDs and all DStream operations as RDD operations, data is automatically recovered in the event of node failures. + +This guide shows some how to start programming with DStreams. + # Initializing Spark Streaming The first thing a Spark Streaming program must do is create a `StreamingContext` object, which tells Spark how to access a cluster. A `StreamingContext` can be created from an existing `SparkContext`, or directly: {% highlight scala %} -new StreamingContext(master, jobName, [sparkHome], [jars]) -new StreamingContext(sparkContext) -{% endhighlight %} - -Once a context is instantiated, the batch interval must be set: +import spark.SparkContext +import SparkContext._ -{% highlight scala %} -context.setBatchDuration(Milliseconds(2000)) +new StreamingContext(master, frameworkName, batchDuration) +new StreamingContext(sparkContext, batchDuration) {% endhighlight %} +The `master` parameter is either the [Mesos master URL](running-on-mesos.html) (for running on a cluster)or the special "local" string (for local mode) that is used to create a Spark Context. For more information about this please refer to the [Spark programming guide](scala-programming-guide.html). -# DStreams - Discretized Streams -The primary abstraction in Spark Streaming is a DStream. A DStream represents distributed collection which is computed periodically according to a specified batch interval. DStream's can be chained together to create complex chains of transformation on streaming data. DStreams can be created by operating on existing DStreams or from an input source. To creating DStreams from an input source, use the StreamingContext: + +# Creating Input Sources - InputDStreams +The StreamingContext is used to creating InputDStreams from input sources: {% highlight scala %} -context.neworkStream(host, port) // A stream that reads from a socket -context.flumeStream(hosts, ports) // A stream populated by a Flume flow +context.neworkStream(host, port) // Creates a stream that uses a TCP socket to read data from : +context.flumeStream(host, port) // Creates a stream populated by a Flume flow {% endhighlight %} -# DStream Operators +A complete list of input sources is available in the [DStream API doc](api/streaming/index.html#spark.streaming.StreamingContext). + +## DStream Operations Once an input stream has been created, you can transform it using _stream operators_. Most of these operators return new DStreams which you can further transform. Eventually, you'll need to call an _output operator_, which forces evaluation of the stream by writing data out to an external source. -## Transformations +### Transformations DStreams support many of the transformations available on normal Spark RDD's: @@ -73,20 +84,13 @@ DStreams support many of the transformations available on normal Spark RDD's: cogroup(otherStream, [numTasks]) When called on streams of type (K, V) and (K, W), returns a stream of (K, Seq[V], Seq[W]) tuples. This operation is also called groupWith. - - -DStreams also support the following additional transformations: - -
reduce(func) Create a new single-element stream by aggregating the elements of the stream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel.
- -## Windowed Transformations -Spark streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a windowTime, which represents the width of the window and a slideTime, which represents the frequency during which the window is calculated. +Spark Streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a windowTime, which represents the width of the window and a slideTime, which represents the frequency during which the window is calculated. @@ -128,7 +132,7 @@ Spark streaming features windowed computations, which allow you to report statis
TransformationMeaning
-## Output Operators +### Output Operators When an output operator is called, it triggers the computation of a stream. Currently the following output operators are defined: @@ -140,22 +144,22 @@ When an output operator is called, it triggers the computation of a stream. Curr - + - + - + - + diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 770f7b0cc0..11a7232d7b 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -8,6 +8,7 @@ import org.apache.hadoop.conf.Configuration import java.io._ +private[streaming] class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Logging with Serializable { val master = ssc.sc.master @@ -30,6 +31,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) /** * Convenience class to speed up the writing of graph checkpoint to file */ +private[streaming] class CheckpointWriter(checkpointDir: String) extends Logging { val file = new Path(checkpointDir, "graph") val conf = new Configuration() @@ -65,7 +67,7 @@ class CheckpointWriter(checkpointDir: String) extends Logging { } - +private[streaming] object CheckpointReader extends Logging { def read(path: String): Checkpoint = { @@ -103,6 +105,7 @@ object CheckpointReader extends Logging { } } +private[streaming] class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoader) extends ObjectInputStream(inputStream_) { override def resolveClass(desc: ObjectStreamClass): Class[_] = { try { diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index d5048aeed7..3834b57ed3 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.conf.Configuration * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous * sequence of RDDs (of the same type) representing a continuous stream of data (see [[spark.RDD]] * for more details on RDDs). DStreams can either be created from live data (such as, data from - * HDFS. Kafka or Flume) or it can be generated by transformation existing DStreams using operations + * HDFS, Kafka or Flume) or it can be generated by transformation existing DStreams using operations * such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each * DStream periodically generates a RDD, either from live data or by transforming the RDD generated * by a parent DStream. @@ -38,33 +38,28 @@ import org.apache.hadoop.conf.Configuration * - A function that is used to generate an RDD after each time interval */ -case class DStreamCheckpointData(rdds: HashMap[Time, Any]) - -abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) -extends Serializable with Logging { +abstract class DStream[T: ClassManifest] ( + @transient protected[streaming] var ssc: StreamingContext + ) extends Serializable with Logging { initLogging() - /** - * ---------------------------------------------- - * Methods that must be implemented by subclasses - * ---------------------------------------------- - */ + // ======================================================================= + // Methods that should be implemented by subclasses of DStream + // ======================================================================= - // Time interval at which the DStream generates an RDD + /** Time interval after which the DStream generates a RDD */ def slideTime: Time - // List of parent DStreams on which this DStream depends on + /** List of parent DStreams on which this DStream depends on */ def dependencies: List[DStream[_]] - // Key method that computes RDD for a valid time + /** Method that generates a RDD for the given time */ def compute (validTime: Time): Option[RDD[T]] - /** - * --------------------------------------- - * Other general fields and methods of DStream - * --------------------------------------- - */ + // ======================================================================= + // Methods and fields available on all DStreams + // ======================================================================= // RDDs generated, marked as protected[streaming] so that testsuites can access it @transient @@ -87,12 +82,15 @@ extends Serializable with Logging { // Reference to whole DStream graph protected[streaming] var graph: DStreamGraph = null - def isInitialized = (zeroTime != null) + protected[streaming] def isInitialized = (zeroTime != null) // Duration for which the DStream requires its parent DStream to remember each RDD created - def parentRememberDuration = rememberDuration + protected[streaming] def parentRememberDuration = rememberDuration + + /** Returns the StreamingContext associated with this DStream */ + def context() = ssc - // Set caching level for the RDDs created by this DStream + /** Persists the RDDs of this DStream with the given storage level */ def persist(level: StorageLevel): DStream[T] = { if (this.isInitialized) { throw new UnsupportedOperationException( @@ -102,11 +100,16 @@ extends Serializable with Logging { this } + /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY_SER) - - // Turn on the default caching level for this RDD + + /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def cache(): DStream[T] = persist() + /** + * Enable periodic checkpointing of RDDs of this DStream + * @param interval Time interval after which generated RDD will be checkpointed + */ def checkpoint(interval: Time): DStream[T] = { if (isInitialized) { throw new UnsupportedOperationException( @@ -285,7 +288,7 @@ extends Serializable with Logging { * Generates a SparkStreaming job for the given time. This is an internal method that * should not be called directly. This default implementation creates a job * that materializes the corresponding RDD. Subclasses of DStream may override this - * (eg. PerRDDForEachDStream). + * (eg. ForEachDStream). */ protected[streaming] def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { @@ -420,65 +423,96 @@ extends Serializable with Logging { generatedRDDs = new HashMap[Time, RDD[T]] () } - /** - * -------------- - * DStream operations - * -------------- - */ + // ======================================================================= + // DStream operations + // ======================================================================= + + /** Returns a new DStream by applying a function to all elements of this DStream. */ def map[U: ClassManifest](mapFunc: T => U): DStream[U] = { new MappedDStream(this, ssc.sc.clean(mapFunc)) } + /** + * Returns a new DStream by applying a function to all elements of this DStream, + * and then flattening the results + */ def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = { new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc)) } + /** Returns a new DStream containing only the elements that satisfy a predicate. */ def filter(filterFunc: T => Boolean): DStream[T] = new FilteredDStream(this, filterFunc) + /** + * Return a new DStream in which each RDD is generated by applying glom() to each RDD of + * this DStream. Applying glom() to an RDD coalesces all elements within each partition into + * an array. + */ def glom(): DStream[Array[T]] = new GlommedDStream(this) - def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]): DStream[U] = { - new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc)) + /** + * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs + * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition + * of the RDD. + */ + def mapPartitions[U: ClassManifest]( + mapPartFunc: Iterator[T] => Iterator[U], + preservePartitioning: Boolean = false + ): DStream[U] = { + new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc), preservePartitioning) } - def reduce(reduceFunc: (T, T) => T): DStream[T] = this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) + /** + * Returns a new DStream in which each RDD has a single element generated by reducing each RDD + * of this DStream. + */ + def reduce(reduceFunc: (T, T) => T): DStream[T] = + this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) + /** + * Returns a new DStream in which each RDD has a single element generated by counting each RDD + * of this DStream. + */ def count(): DStream[Int] = this.map(_ => 1).reduce(_ + _) - - def collect(): DStream[Seq[T]] = this.map(x => (null, x)).groupByKey(1).map(_._2) - - def foreach(foreachFunc: T => Unit) { - val newStream = new PerElementForEachDStream(this, ssc.sc.clean(foreachFunc)) - ssc.registerOutputStream(newStream) - newStream - } - def foreachRDD(foreachFunc: RDD[T] => Unit) { - foreachRDD((r: RDD[T], t: Time) => foreachFunc(r)) + /** + * Applies a function to each RDD in this DStream. This is an output operator, so + * this DStream will be registered as an output stream and therefore materialized. + */ + def foreach(foreachFunc: RDD[T] => Unit) { + foreach((r: RDD[T], t: Time) => foreachFunc(r)) } - def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) { - val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc)) + /** + * Applies a function to each RDD in this DStream. This is an output operator, so + * this DStream will be registered as an output stream and therefore materialized. + */ + def foreach(foreachFunc: (RDD[T], Time) => Unit) { + val newStream = new ForEachDStream(this, ssc.sc.clean(foreachFunc)) ssc.registerOutputStream(newStream) newStream } - def transformRDD[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = { - transformRDD((r: RDD[T], t: Time) => transformFunc(r)) + /** + * Returns a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. + */ + def transform[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = { + transform((r: RDD[T], t: Time) => transformFunc(r)) } - def transformRDD[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { + /** + * Returns a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. + */ + def transform[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { new TransformedDStream(this, ssc.sc.clean(transformFunc)) } - def toBlockingQueue() = { - val queue = new ArrayBlockingQueue[RDD[T]](10000) - this.foreachRDD(rdd => { - queue.add(rdd) - }) - queue - } - + /** + * Prints the first ten elements of each RDD generated in this DStream. This is an output + * operator, so this DStream will be registered as an output stream and there materialized. + */ def print() { def foreachFunc = (rdd: RDD[T], time: Time) => { val first11 = rdd.take(11) @@ -489,18 +523,42 @@ extends Serializable with Logging { if (first11.size > 10) println("...") println() } - val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc)) + val newStream = new ForEachDStream(this, ssc.sc.clean(foreachFunc)) ssc.registerOutputStream(newStream) } + /** + * Return a new DStream which is computed based on windowed batches of this DStream. + * The new DStream generates RDDs with the same interval as this DStream. + * @param windowTime width of the window; must be a multiple of this DStream's interval. + * @return + */ def window(windowTime: Time): DStream[T] = window(windowTime, this.slideTime) + /** + * Return a new DStream which is computed based on windowed batches of this DStream. + * @param windowTime duration (i.e., width) of the window; + * must be a multiple of this DStream's interval + * @param slideTime sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's interval + */ def window(windowTime: Time, slideTime: Time): DStream[T] = { new WindowedDStream(this, windowTime, slideTime) } + /** + * Returns a new DStream which computed based on tumbling window on this DStream. + * This is equivalent to window(batchTime, batchTime). + * @param batchTime tumbling window duration; must be a multiple of this DStream's interval + */ def tumble(batchTime: Time): DStream[T] = window(batchTime, batchTime) + /** + * Returns a new DStream in which each RDD has a single element generated by reducing all + * elements in a window over this DStream. windowTime and slideTime are as defined in the + * window() operation. This is equivalent to window(windowTime, slideTime).reduce(reduceFunc) + */ def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Time, slideTime: Time): DStream[T] = { this.window(windowTime, slideTime).reduce(reduceFunc) } @@ -516,17 +574,31 @@ extends Serializable with Logging { .map(_._2) } + /** + * Returns a new DStream in which each RDD has a single element generated by counting the number + * of elements in a window over this DStream. windowTime and slideTime are as defined in the + * window() operation. This is equivalent to window(windowTime, slideTime).count() + */ def countByWindow(windowTime: Time, slideTime: Time): DStream[Int] = { this.map(_ => 1).reduceByWindow(_ + _, _ - _, windowTime, slideTime) } + /** + * Returns a new DStream by unifying data of another DStream with this DStream. + * @param that Another DStream having the same interval (i.e., slideTime) as this DStream. + */ def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that)) - def slice(interval: Interval): Seq[RDD[T]] = { + /** + * Returns all the RDDs defined by the Interval object (both end times included) + */ + protected[streaming] def slice(interval: Interval): Seq[RDD[T]] = { slice(interval.beginTime, interval.endTime) } - // Get all the RDDs between fromTime to toTime (both included) + /** + * Returns all the RDDs between 'fromTime' to 'toTime' (both included) + */ def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { val rdds = new ArrayBuffer[RDD[T]]() var time = toTime.floor(slideTime) @@ -540,20 +612,26 @@ extends Serializable with Logging { rdds.toSeq } + /** + * Saves each RDD in this DStream as a Sequence file of serialized objects. + */ def saveAsObjectFiles(prefix: String, suffix: String = "") { val saveFunc = (rdd: RDD[T], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsObjectFile(file) } - this.foreachRDD(saveFunc) + this.foreach(saveFunc) } + /** + * Saves each RDD in this DStream as at text file, using string representation of elements. + */ def saveAsTextFiles(prefix: String, suffix: String = "") { val saveFunc = (rdd: RDD[T], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsTextFile(file) } - this.foreachRDD(saveFunc) + this.foreach(saveFunc) } def register() { @@ -561,6 +639,8 @@ extends Serializable with Logging { } } +private[streaming] +case class DStreamCheckpointData(rdds: HashMap[Time, Any]) abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) extends DStream[T](ssc_) { @@ -583,6 +663,7 @@ abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContex * TODO */ +private[streaming] class MappedDStream[T: ClassManifest, U: ClassManifest] ( parent: DStream[T], mapFunc: T => U @@ -602,6 +683,7 @@ class MappedDStream[T: ClassManifest, U: ClassManifest] ( * TODO */ +private[streaming] class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( parent: DStream[T], flatMapFunc: T => Traversable[U] @@ -621,6 +703,7 @@ class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( * TODO */ +private[streaming] class FilteredDStream[T: ClassManifest]( parent: DStream[T], filterFunc: T => Boolean @@ -640,9 +723,11 @@ class FilteredDStream[T: ClassManifest]( * TODO */ +private[streaming] class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( parent: DStream[T], - mapPartFunc: Iterator[T] => Iterator[U] + mapPartFunc: Iterator[T] => Iterator[U], + preservePartitioning: Boolean ) extends DStream[U](parent.ssc) { override def dependencies = List(parent) @@ -650,7 +735,7 @@ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( override def slideTime: Time = parent.slideTime override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc)) + parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc, preservePartitioning)) } } @@ -659,6 +744,7 @@ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( * TODO */ +private[streaming] class GlommedDStream[T: ClassManifest](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { @@ -676,6 +762,7 @@ class GlommedDStream[T: ClassManifest](parent: DStream[T]) * TODO */ +private[streaming] class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( parent: DStream[(K,V)], createCombiner: V => C, @@ -702,6 +789,7 @@ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( * TODO */ +private[streaming] class MapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( parent: DStream[(K, V)], mapValueFunc: V => U @@ -720,7 +808,7 @@ class MapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( /** * TODO */ - +private[streaming] class FlatMapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( parent: DStream[(K, V)], flatMapValueFunc: V => TraversableOnce[U] @@ -779,38 +867,8 @@ class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) * TODO */ -class PerElementForEachDStream[T: ClassManifest] ( - parent: DStream[T], - foreachFunc: T => Unit - ) extends DStream[Unit](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[Unit]] = None - - override def generateJob(time: Time): Option[Job] = { - parent.getOrCompute(time) match { - case Some(rdd) => - val jobFunc = () => { - val sparkJobFunc = { - (iterator: Iterator[T]) => iterator.foreach(foreachFunc) - } - ssc.sc.runJob(rdd, sparkJobFunc) - } - Some(new Job(time, jobFunc)) - case None => None - } - } -} - - -/** - * TODO - */ - -class PerRDDForEachDStream[T: ClassManifest] ( +private[streaming] +class ForEachDStream[T: ClassManifest] ( parent: DStream[T], foreachFunc: (RDD[T], Time) => Unit ) extends DStream[Unit](parent.ssc) { @@ -838,6 +896,7 @@ class PerRDDForEachDStream[T: ClassManifest] ( * TODO */ +private[streaming] class TransformedDStream[T: ClassManifest, U: ClassManifest] ( parent: DStream[T], transformFunc: (RDD[T], Time) => RDD[U] diff --git a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala index 2959ce4540..5ac7e5b08e 100644 --- a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala @@ -79,7 +79,7 @@ class SparkFlumeEvent() extends Externalizable { } } -object SparkFlumeEvent { +private[streaming] object SparkFlumeEvent { def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = { val event = new SparkFlumeEvent event.event = in diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala index ffb7725ac9..fa0b7ce19d 100644 --- a/streaming/src/main/scala/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/spark/streaming/Interval.scala @@ -1,5 +1,6 @@ package spark.streaming +private[streaming] case class Interval(beginTime: Time, endTime: Time) { def this(beginMs: Long, endMs: Long) = this(Time(beginMs), new Time(endMs)) diff --git a/streaming/src/main/scala/spark/streaming/Job.scala b/streaming/src/main/scala/spark/streaming/Job.scala index 0bcb6fd8dc..67bd8388bc 100644 --- a/streaming/src/main/scala/spark/streaming/Job.scala +++ b/streaming/src/main/scala/spark/streaming/Job.scala @@ -2,6 +2,7 @@ package spark.streaming import java.util.concurrent.atomic.AtomicLong +private[streaming] class Job(val time: Time, func: () => _) { val id = Job.getNewId() def run(): Long = { @@ -14,6 +15,7 @@ class Job(val time: Time, func: () => _) { override def toString = "streaming job " + id + " @ " + time } +private[streaming] object Job { val id = new AtomicLong(0) diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 9bf9251519..fda7264a27 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -5,6 +5,7 @@ import spark.SparkEnv import java.util.concurrent.Executors +private[streaming] class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { class JobHandler(ssc: StreamingContext, job: Job) extends Runnable { diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala index 4e4e9fc942..4bf13dd50c 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -40,10 +40,10 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming } -sealed trait NetworkReceiverMessage -case class StopReceiver(msg: String) extends NetworkReceiverMessage -case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage -case class ReportError(msg: String) extends NetworkReceiverMessage +private[streaming] sealed trait NetworkReceiverMessage +private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage +private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage +private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging { diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index b421f795ee..658498dfc1 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -11,10 +11,10 @@ import akka.pattern.ask import akka.util.duration._ import akka.dispatch._ -trait NetworkInputTrackerMessage -case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage -case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage -case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage +private[streaming] sealed trait NetworkInputTrackerMessage +private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage +private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage +private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage class NetworkInputTracker( diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 720e63bba0..f9fef14196 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -281,7 +281,7 @@ extends Serializable { val file = rddToFileName(prefix, suffix, time) rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) } - self.foreachRDD(saveFunc) + self.foreach(saveFunc) } def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]]( @@ -303,7 +303,7 @@ extends Serializable { val file = rddToFileName(prefix, suffix, time) rdd.saveAsNewAPIHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) } - self.foreachRDD(saveFunc) + self.foreach(saveFunc) } private def getKeyClass() = implicitly[ClassManifest[K]].erasure diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 014021be61..fd1fa77a24 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -7,11 +7,8 @@ import spark.Logging import scala.collection.mutable.HashMap -sealed trait SchedulerMessage -case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage - -class Scheduler(ssc: StreamingContext) -extends Logging { +private[streaming] +class Scheduler(ssc: StreamingContext) extends Logging { initLogging() diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index ce47bcb2da..998fea849f 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -48,7 +48,7 @@ class StreamingContext private ( this(StreamingContext.createNewSparkContext(master, frameworkName), null, batchDuration) /** - * Recreates the StreamingContext from a checkpoint file. + * Re-creates a StreamingContext from a checkpoint file. * @param path Path either to the directory that was specified as the checkpoint directory, or * to the checkpoint file 'graph' or 'graph.bk'. */ @@ -61,7 +61,7 @@ class StreamingContext private ( "both SparkContext and checkpoint as null") } - val isCheckpointPresent = (cp_ != null) + protected[streaming] val isCheckpointPresent = (cp_ != null) val sc: SparkContext = { if (isCheckpointPresent) { @@ -71,9 +71,9 @@ class StreamingContext private ( } } - val env = SparkEnv.get + protected[streaming] val env = SparkEnv.get - val graph: DStreamGraph = { + protected[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { cp_.graph.setContext(this) cp_.graph.restoreCheckpointData() @@ -86,10 +86,10 @@ class StreamingContext private ( } } - private[streaming] val nextNetworkInputStreamId = new AtomicInteger(0) - private[streaming] var networkInputTracker: NetworkInputTracker = null + protected[streaming] val nextNetworkInputStreamId = new AtomicInteger(0) + protected[streaming] var networkInputTracker: NetworkInputTracker = null - private[streaming] var checkpointDir: String = { + protected[streaming] var checkpointDir: String = { if (isCheckpointPresent) { sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(cp_.checkpointDir), true) cp_.checkpointDir @@ -98,9 +98,9 @@ class StreamingContext private ( } } - private[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null - private[streaming] var receiverJobThread: Thread = null - private[streaming] var scheduler: Scheduler = null + protected[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null + protected[streaming] var receiverJobThread: Thread = null + protected[streaming] var scheduler: Scheduler = null def remember(duration: Time) { graph.remember(duration) @@ -117,11 +117,11 @@ class StreamingContext private ( } } - private[streaming] def getInitialCheckpoint(): Checkpoint = { + protected[streaming] def getInitialCheckpoint(): Checkpoint = { if (isCheckpointPresent) cp_ else null } - private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() + protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() /** * Create an input stream that pulls messages form a Kafka Broker. @@ -188,7 +188,7 @@ class StreamingContext private ( } /** - * This function creates a input stream that monitors a Hadoop-compatible filesystem + * Creates a input stream that monitors a Hadoop-compatible filesystem * for new files and executes the necessary processing on them. */ def fileStream[ @@ -206,7 +206,7 @@ class StreamingContext private ( } /** - * This function create a input stream from an queue of RDDs. In each batch, + * Creates a input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue */ def queueStream[T: ClassManifest]( @@ -231,22 +231,21 @@ class StreamingContext private ( } /** - * This function registers a InputDStream as an input stream that will be - * started (InputDStream.start() called) to get the input data streams. + * Registers an input stream that will be started (InputDStream.start() called) to get the + * input data. */ def registerInputStream(inputStream: InputDStream[_]) { graph.addInputStream(inputStream) } /** - * This function registers a DStream as an output stream that will be - * computed every interval. + * Registers an output stream that will be computed every interval */ def registerOutputStream(outputStream: DStream[_]) { graph.addOutputStream(outputStream) } - def validate() { + protected def validate() { assert(graph != null, "Graph is null") graph.validate() @@ -304,7 +303,7 @@ class StreamingContext private ( object StreamingContext { - def createNewSparkContext(master: String, frameworkName: String): SparkContext = { + protected[streaming] def createNewSparkContext(master: String, frameworkName: String): SparkContext = { // Set the default cleaner delay to an hour if not already set. // This should be sufficient for even 1 second interval. @@ -318,7 +317,7 @@ object StreamingContext { new PairDStreamFunctions[K, V](stream) } - def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { + protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { if (prefix == null) { time.millis.toString } else if (suffix == null || suffix.length ==0) { @@ -328,7 +327,7 @@ object StreamingContext { } } - def getSparkCheckpointDir(sscCheckpointDir: String): String = { + protected[streaming] def getSparkCheckpointDir(sscCheckpointDir: String): String = { new Path(sscCheckpointDir, UUID.randomUUID.toString).toString } } diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala index 6cb2b4c042..7c4ee3b34c 100644 --- a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala @@ -26,7 +26,7 @@ object GrepRaw { val rawStreams = (1 to numStreams).map(_ => ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray val union = new UnionDStream(rawStreams) - union.filter(_.contains("Alice")).count().foreachRDD(r => + union.filter(_.contains("Alice")).count().foreach(r => println("Grep count: " + r.collect().mkString)) ssc.start() } diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index fe4c2bf155..182dfd8a52 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -38,7 +38,7 @@ object TopKWordCountRaw { val counts = union.mapPartitions(splitAndCountPartitions) val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { + partialTopKWindowedCounts.foreach(rdd => { val collectedCounts = rdd.collect println("Collected " + collectedCounts.size + " words from partial top words") println("Top " + k + " words are " + topK(collectedCounts.toIterator, k).mkString(",")) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index a29c81d437..9bcd30f4d7 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -36,7 +36,7 @@ object WordCountRaw { val union = new UnionDStream(lines.toArray) val counts = union.mapPartitions(splitAndCountPartitions) val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) - windowedCounts.foreachRDD(r => println("# unique words = " + r.count())) + windowedCounts.foreach(r => println("# unique words = " + r.count())) ssc.start() } diff --git a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala b/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala index 68be6b7893..a191321d91 100644 --- a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala +++ b/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala @@ -72,7 +72,7 @@ object PageViewStream { case "popularUsersSeen" => // Look for users in our existing dataset and print it out if we have a match pageViews.map(view => (view.userID, 1)) - .foreachRDD((rdd, time) => rdd.join(userList) + .foreach((rdd, time) => rdd.join(userList) .map(_._2._2) .take(10) .foreach(u => println("Saw user %s at time %s".format(u, time)))) diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index 8cc2f8ccfc..a44f738957 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -35,7 +35,7 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[ * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. */ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]]) - extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { + extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() output += collected }) { -- cgit v1.2.3 From 7e0271b4387eaf27cd96f3057ce2465b1271a480 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 30 Dec 2012 15:19:55 -0800 Subject: Refactored a whole lot to push all DStreams into the spark.streaming.dstream package. --- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 1 + .../scala/spark/streaming/CoGroupedDStream.scala | 38 --- .../spark/streaming/ConstantInputDStream.scala | 18 -- .../src/main/scala/spark/streaming/DStream.scala | 276 +-------------------- .../main/scala/spark/streaming/DStreamGraph.scala | 1 + .../main/scala/spark/streaming/DataHandler.scala | 83 ------- .../scala/spark/streaming/FileInputDStream.scala | 109 -------- .../scala/spark/streaming/FlumeInputDStream.scala | 130 ---------- .../spark/streaming/NetworkInputDStream.scala | 156 ------------ .../spark/streaming/NetworkInputTracker.scala | 2 + .../spark/streaming/PairDStreamFunctions.scala | 7 +- .../scala/spark/streaming/QueueInputDStream.scala | 40 --- .../scala/spark/streaming/RawInputDStream.scala | 85 ------- .../spark/streaming/ReducedWindowedDStream.scala | 149 ----------- .../src/main/scala/spark/streaming/Scheduler.scala | 3 - .../scala/spark/streaming/SocketInputDStream.scala | 107 -------- .../main/scala/spark/streaming/StateDStream.scala | 84 ------- .../scala/spark/streaming/StreamingContext.scala | 13 +- .../src/main/scala/spark/streaming/Time.scala | 11 +- .../scala/spark/streaming/WindowedDStream.scala | 39 --- .../spark/streaming/dstream/CoGroupedDStream.scala | 39 +++ .../streaming/dstream/ConstantInputDStream.scala | 19 ++ .../spark/streaming/dstream/DataHandler.scala | 83 +++++++ .../spark/streaming/dstream/FileInputDStream.scala | 110 ++++++++ .../spark/streaming/dstream/FilteredDStream.scala | 21 ++ .../streaming/dstream/FlatMapValuedDStream.scala | 20 ++ .../streaming/dstream/FlatMappedDStream.scala | 20 ++ .../streaming/dstream/FlumeInputDStream.scala | 135 ++++++++++ .../spark/streaming/dstream/ForEachDStream.scala | 28 +++ .../spark/streaming/dstream/GlommedDStream.scala | 17 ++ .../spark/streaming/dstream/InputDStream.scala | 19 ++ .../streaming/dstream/KafkaInputDStream.scala | 197 +++++++++++++++ .../streaming/dstream/MapPartitionedDStream.scala | 21 ++ .../spark/streaming/dstream/MapValuedDStream.scala | 21 ++ .../spark/streaming/dstream/MappedDStream.scala | 20 ++ .../streaming/dstream/NetworkInputDStream.scala | 157 ++++++++++++ .../streaming/dstream/QueueInputDStream.scala | 41 +++ .../spark/streaming/dstream/RawInputDStream.scala | 88 +++++++ .../streaming/dstream/ReducedWindowedDStream.scala | 148 +++++++++++ .../spark/streaming/dstream/ShuffledDStream.scala | 27 ++ .../streaming/dstream/SocketInputDStream.scala | 103 ++++++++ .../spark/streaming/dstream/StateDStream.scala | 83 +++++++ .../streaming/dstream/TransformedDStream.scala | 19 ++ .../spark/streaming/dstream/UnionDStream.scala | 39 +++ .../spark/streaming/dstream/WindowedDStream.scala | 40 +++ .../scala/spark/streaming/examples/GrepRaw.scala | 2 +- .../streaming/examples/TopKWordCountRaw.scala | 2 +- .../spark/streaming/examples/WordCountRaw.scala | 2 +- .../spark/streaming/input/KafkaInputDStream.scala | 193 -------------- .../scala/spark/streaming/CheckpointSuite.scala | 2 +- .../test/scala/spark/streaming/FailureSuite.scala | 2 +- .../scala/spark/streaming/InputStreamsSuite.scala | 1 + .../test/scala/spark/streaming/TestSuiteBase.scala | 48 +++- .../spark/streaming/WindowOperationsSuite.scala | 12 +- 54 files changed, 1600 insertions(+), 1531 deletions(-) delete mode 100644 streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/DataHandler.scala delete mode 100644 streaming/src/main/scala/spark/streaming/FileInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/QueueInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/RawInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/SocketInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/StateDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/WindowedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/DataHandler.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index f40b56be64..1b219473e0 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -1,6 +1,7 @@ package spark.rdd import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext} +import spark.SparkContext._ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { override val index = idx diff --git a/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala deleted file mode 100644 index 61d088eddb..0000000000 --- a/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala +++ /dev/null @@ -1,38 +0,0 @@ -package spark.streaming - -import spark.{RDD, Partitioner} -import spark.rdd.CoGroupedRDD - -class CoGroupedDStream[K : ClassManifest]( - parents: Seq[DStream[(_, _)]], - partitioner: Partitioner - ) extends DStream[(K, Seq[Seq[_]])](parents.head.ssc) { - - if (parents.length == 0) { - throw new IllegalArgumentException("Empty array of parents") - } - - if (parents.map(_.ssc).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different StreamingContexts") - } - - if (parents.map(_.slideTime).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different slide times") - } - - override def dependencies = parents.toList - - override def slideTime = parents.head.slideTime - - override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = { - val part = partitioner - val rdds = parents.flatMap(_.getOrCompute(validTime)) - if (rdds.size > 0) { - val q = new CoGroupedRDD[K](rdds, part) - Some(q) - } else { - None - } - } - -} diff --git a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala deleted file mode 100644 index 80150708fd..0000000000 --- a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala +++ /dev/null @@ -1,18 +0,0 @@ -package spark.streaming - -import spark.RDD - -/** - * An input stream that always returns the same RDD on each timestep. Useful for testing. - */ -class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T]) - extends InputDStream[T](ssc_) { - - override def start() {} - - override def stop() {} - - override def compute(validTime: Time): Option[RDD[T]] = { - Some(rdd) - } -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 3834b57ed3..292ad3b9f9 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -1,17 +1,15 @@ package spark.streaming +import spark.streaming.dstream._ import StreamingContext._ import Time._ -import spark._ -import spark.SparkContext._ -import spark.rdd._ +import spark.{RDD, Logging} import spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import java.util.concurrent.ArrayBlockingQueue import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import org.apache.hadoop.fs.Path @@ -197,7 +195,7 @@ abstract class DStream[T: ClassManifest] ( "than " + rememberDuration.milliseconds + " milliseconds. But the Spark's metadata cleanup" + "delay is set to " + (metadataCleanerDelay / 60.0) + " minutes, which is not sufficient. Please set " + "the Java property 'spark.cleaner.delay' to more than " + - math.ceil(rememberDuration.millis.toDouble / 60000.0).toInt + " minutes." + math.ceil(rememberDuration.milliseconds.toDouble / 60000.0).toInt + " minutes." ) dependencies.foreach(_.validate()) @@ -642,271 +640,3 @@ abstract class DStream[T: ClassManifest] ( private[streaming] case class DStreamCheckpointData(rdds: HashMap[Time, Any]) -abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) - extends DStream[T](ssc_) { - - override def dependencies = List() - - override def slideTime = { - if (ssc == null) throw new Exception("ssc is null") - if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null") - ssc.graph.batchDuration - } - - def start() - - def stop() -} - - -/** - * TODO - */ - -private[streaming] -class MappedDStream[T: ClassManifest, U: ClassManifest] ( - parent: DStream[T], - mapFunc: T => U - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.map[U](mapFunc)) - } -} - - -/** - * TODO - */ - -private[streaming] -class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( - parent: DStream[T], - flatMapFunc: T => Traversable[U] - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) - } -} - - -/** - * TODO - */ - -private[streaming] -class FilteredDStream[T: ClassManifest]( - parent: DStream[T], - filterFunc: T => Boolean - ) extends DStream[T](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[T]] = { - parent.getOrCompute(validTime).map(_.filter(filterFunc)) - } -} - - -/** - * TODO - */ - -private[streaming] -class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( - parent: DStream[T], - mapPartFunc: Iterator[T] => Iterator[U], - preservePartitioning: Boolean - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc, preservePartitioning)) - } -} - - -/** - * TODO - */ - -private[streaming] -class GlommedDStream[T: ClassManifest](parent: DStream[T]) - extends DStream[Array[T]](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[Array[T]]] = { - parent.getOrCompute(validTime).map(_.glom()) - } -} - - -/** - * TODO - */ - -private[streaming] -class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( - parent: DStream[(K,V)], - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiner: (C, C) => C, - partitioner: Partitioner - ) extends DStream [(K,C)] (parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[(K,C)]] = { - parent.getOrCompute(validTime) match { - case Some(rdd) => - Some(rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner)) - case None => None - } - } -} - - -/** - * TODO - */ - -private[streaming] -class MapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( - parent: DStream[(K, V)], - mapValueFunc: V => U - ) extends DStream[(K, U)](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[(K, U)]] = { - parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc)) - } -} - - -/** - * TODO - */ -private[streaming] -class FlatMapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( - parent: DStream[(K, V)], - flatMapValueFunc: V => TraversableOnce[U] - ) extends DStream[(K, U)](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[(K, U)]] = { - parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc)) - } -} - - - -/** - * TODO - */ - -class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) - extends DStream[T](parents.head.ssc) { - - if (parents.length == 0) { - throw new IllegalArgumentException("Empty array of parents") - } - - if (parents.map(_.ssc).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different StreamingContexts") - } - - if (parents.map(_.slideTime).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different slide times") - } - - override def dependencies = parents.toList - - override def slideTime: Time = parents.head.slideTime - - override def compute(validTime: Time): Option[RDD[T]] = { - val rdds = new ArrayBuffer[RDD[T]]() - parents.map(_.getOrCompute(validTime)).foreach(_ match { - case Some(rdd) => rdds += rdd - case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) - }) - if (rdds.size > 0) { - Some(new UnionRDD(ssc.sc, rdds)) - } else { - None - } - } -} - - -/** - * TODO - */ - -private[streaming] -class ForEachDStream[T: ClassManifest] ( - parent: DStream[T], - foreachFunc: (RDD[T], Time) => Unit - ) extends DStream[Unit](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[Unit]] = None - - override def generateJob(time: Time): Option[Job] = { - parent.getOrCompute(time) match { - case Some(rdd) => - val jobFunc = () => { - foreachFunc(rdd, time) - } - Some(new Job(time, jobFunc)) - case None => None - } - } -} - - -/** - * TODO - */ - -private[streaming] -class TransformedDStream[T: ClassManifest, U: ClassManifest] ( - parent: DStream[T], - transformFunc: (RDD[T], Time) => RDD[U] - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(transformFunc(_, validTime)) - } -} diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index d0a9ade61d..c72429370e 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -1,5 +1,6 @@ package spark.streaming +import dstream.InputDStream import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import collection.mutable.ArrayBuffer import spark.Logging diff --git a/streaming/src/main/scala/spark/streaming/DataHandler.scala b/streaming/src/main/scala/spark/streaming/DataHandler.scala deleted file mode 100644 index 05f307a8d1..0000000000 --- a/streaming/src/main/scala/spark/streaming/DataHandler.scala +++ /dev/null @@ -1,83 +0,0 @@ -package spark.streaming - -import java.util.concurrent.ArrayBlockingQueue -import scala.collection.mutable.ArrayBuffer -import spark.Logging -import spark.streaming.util.{RecurringTimer, SystemClock} -import spark.storage.StorageLevel - - -/** - * This is a helper object that manages the data received from the socket. It divides - * the object received into small batches of 100s of milliseconds, pushes them as - * blocks into the block manager and reports the block IDs to the network input - * tracker. It starts two threads, one to periodically start a new batch and prepare - * the previous batch of as a block, the other to push the blocks into the block - * manager. - */ - class DataHandler[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel) - extends Serializable with Logging { - - case class Block(id: String, iterator: Iterator[T], metadata: Any = null) - - val clock = new SystemClock() - val blockInterval = 200L - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) - val blockStorageLevel = storageLevel - val blocksForPushing = new ArrayBlockingQueue[Block](1000) - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - var currentBuffer = new ArrayBuffer[T] - - def createBlock(blockId: String, iterator: Iterator[T]) : Block = { - new Block(blockId, iterator) - } - - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Data handler started") - } - - def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - logInfo("Data handler stopped") - } - - def += (obj: T) { - currentBuffer += obj - } - - def updateCurrentBuffer(time: Long) { - try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[T] - if (newBlockBuffer.size > 0) { - val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval) - val newBlock = createBlock(blockId, newBlockBuffer.toIterator) - blocksForPushing.add(newBlock) - } - } catch { - case ie: InterruptedException => - logInfo("Block interval timer thread interrupted") - case e: Exception => - receiver.stop() - } - } - - def keepPushingBlocks() { - logInfo("Block pushing thread started") - try { - while(true) { - val block = blocksForPushing.take() - receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel) - } - } catch { - case ie: InterruptedException => - logInfo("Block pushing thread interrupted") - case e: Exception => - receiver.stop() - } - } - } \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala deleted file mode 100644 index 88856364d2..0000000000 --- a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala +++ /dev/null @@ -1,109 +0,0 @@ -package spark.streaming - -import spark.RDD -import spark.rdd.UnionRDD - -import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} - -import scala.collection.mutable.HashSet - - -class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( - @transient ssc_ : StreamingContext, - directory: String, - filter: PathFilter = FileInputDStream.defaultPathFilter, - newFilesOnly: Boolean = true) - extends InputDStream[(K, V)](ssc_) { - - @transient private var path_ : Path = null - @transient private var fs_ : FileSystem = null - - var lastModTime = 0L - val lastModTimeFiles = new HashSet[String]() - - def path(): Path = { - if (path_ == null) path_ = new Path(directory) - path_ - } - - def fs(): FileSystem = { - if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) - fs_ - } - - override def start() { - if (newFilesOnly) { - lastModTime = System.currentTimeMillis() - } else { - lastModTime = 0 - } - } - - override def stop() { } - - /** - * Finds the files that were modified since the last time this method was called and makes - * a union RDD out of them. Note that this maintains the list of files that were processed - * in the latest modification time in the previous call to this method. This is because the - * modification time returned by the FileStatus API seems to return times only at the - * granularity of seconds. Hence, new files may have the same modification time as the - * latest modification time in the previous call to this method and the list of files - * maintained is used to filter the one that have been processed. - */ - override def compute(validTime: Time): Option[RDD[(K, V)]] = { - // Create the filter for selecting new files - val newFilter = new PathFilter() { - var latestModTime = 0L - val latestModTimeFiles = new HashSet[String]() - - def accept(path: Path): Boolean = { - if (!filter.accept(path)) { - return false - } else { - val modTime = fs.getFileStatus(path).getModificationTime() - if (modTime < lastModTime){ - return false - } else if (modTime == lastModTime && lastModTimeFiles.contains(path.toString)) { - return false - } - if (modTime > latestModTime) { - latestModTime = modTime - latestModTimeFiles.clear() - } - latestModTimeFiles += path.toString - return true - } - } - } - - val newFiles = fs.listStatus(path, newFilter) - logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) - if (newFiles.length > 0) { - // Update the modification time and the files processed for that modification time - if (lastModTime != newFilter.latestModTime) { - lastModTime = newFilter.latestModTime - lastModTimeFiles.clear() - } - lastModTimeFiles ++= newFilter.latestModTimeFiles - } - val newRDD = new UnionRDD(ssc.sc, newFiles.map( - file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) - Some(newRDD) - } -} - -object FileInputDStream { - val defaultPathFilter = new PathFilter with Serializable { - def accept(path: Path): Boolean = { - val file = path.getName() - if (file.startsWith(".") || file.endsWith("_tmp")) { - return false - } else { - return true - } - } - } -} - diff --git a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala deleted file mode 100644 index 5ac7e5b08e..0000000000 --- a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala +++ /dev/null @@ -1,130 +0,0 @@ -package spark.streaming - -import java.io.{ObjectInput, ObjectOutput, Externalizable} -import spark.storage.StorageLevel -import org.apache.flume.source.avro.AvroSourceProtocol -import org.apache.flume.source.avro.AvroFlumeEvent -import org.apache.flume.source.avro.Status -import org.apache.avro.ipc.specific.SpecificResponder -import org.apache.avro.ipc.NettyServer -import java.net.InetSocketAddress -import collection.JavaConversions._ -import spark.Utils -import java.nio.ByteBuffer - -class FlumeInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - storageLevel: StorageLevel -) extends NetworkInputDStream[SparkFlumeEvent](ssc_) { - - override def createReceiver(): NetworkReceiver[SparkFlumeEvent] = { - new FlumeReceiver(id, host, port, storageLevel) - } -} - -/** - * A wrapper class for AvroFlumeEvent's with a custom serialization format. - * - * This is necessary because AvroFlumeEvent uses inner data structures - * which are not serializable. - */ -class SparkFlumeEvent() extends Externalizable { - var event : AvroFlumeEvent = new AvroFlumeEvent() - - /* De-serialize from bytes. */ - def readExternal(in: ObjectInput) { - val bodyLength = in.readInt() - val bodyBuff = new Array[Byte](bodyLength) - in.read(bodyBuff) - - val numHeaders = in.readInt() - val headers = new java.util.HashMap[CharSequence, CharSequence] - - for (i <- 0 until numHeaders) { - val keyLength = in.readInt() - val keyBuff = new Array[Byte](keyLength) - in.read(keyBuff) - val key : String = Utils.deserialize(keyBuff) - - val valLength = in.readInt() - val valBuff = new Array[Byte](valLength) - in.read(valBuff) - val value : String = Utils.deserialize(valBuff) - - headers.put(key, value) - } - - event.setBody(ByteBuffer.wrap(bodyBuff)) - event.setHeaders(headers) - } - - /* Serialize to bytes. */ - def writeExternal(out: ObjectOutput) { - val body = event.getBody.array() - out.writeInt(body.length) - out.write(body) - - val numHeaders = event.getHeaders.size() - out.writeInt(numHeaders) - for ((k, v) <- event.getHeaders) { - val keyBuff = Utils.serialize(k.toString) - out.writeInt(keyBuff.length) - out.write(keyBuff) - val valBuff = Utils.serialize(v.toString) - out.writeInt(valBuff.length) - out.write(valBuff) - } - } -} - -private[streaming] object SparkFlumeEvent { - def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = { - val event = new SparkFlumeEvent - event.event = in - event - } -} - -/** A simple server that implements Flume's Avro protocol. */ -class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { - override def append(event : AvroFlumeEvent) : Status = { - receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event) - Status.OK - } - - override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { - events.foreach (event => - receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event)) - Status.OK - } -} - -/** A NetworkReceiver which listens for events using the - * Flume Avro interface.*/ -class FlumeReceiver( - streamId: Int, - host: String, - port: Int, - storageLevel: StorageLevel - ) extends NetworkReceiver[SparkFlumeEvent](streamId) { - - lazy val dataHandler = new DataHandler(this, storageLevel) - - protected override def onStart() { - val responder = new SpecificResponder( - classOf[AvroSourceProtocol], new FlumeEventServer(this)); - val server = new NettyServer(responder, new InetSocketAddress(host, port)); - dataHandler.start() - server.start() - logInfo("Flume receiver started") - } - - protected override def onStop() { - dataHandler.stop() - logInfo("Flume receiver stopped") - } - - override def getLocationPreference = Some(host) -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala deleted file mode 100644 index 4bf13dd50c..0000000000 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ /dev/null @@ -1,156 +0,0 @@ -package spark.streaming - -import scala.collection.mutable.ArrayBuffer - -import spark.{Logging, SparkEnv, RDD} -import spark.rdd.BlockRDD -import spark.streaming.util.{RecurringTimer, SystemClock} -import spark.storage.StorageLevel - -import java.nio.ByteBuffer - -import akka.actor.{Props, Actor} -import akka.pattern.ask -import akka.dispatch.Await -import akka.util.duration._ - -abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext) - extends InputDStream[T](ssc_) { - - // This is an unique identifier that is used to match the network receiver with the - // corresponding network input stream. - val id = ssc.getNewNetworkStreamId() - - /** - * This method creates the receiver object that will be sent to the workers - * to receive data. This method needs to defined by any specific implementation - * of a NetworkInputDStream. - */ - def createReceiver(): NetworkReceiver[T] - - // Nothing to start or stop as both taken care of by the NetworkInputTracker. - def start() {} - - def stop() {} - - override def compute(validTime: Time): Option[RDD[T]] = { - val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) - Some(new BlockRDD[T](ssc.sc, blockIds)) - } -} - - -private[streaming] sealed trait NetworkReceiverMessage -private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage -private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage -private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage - -abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging { - - initLogging() - - lazy protected val env = SparkEnv.get - - lazy protected val actor = env.actorSystem.actorOf( - Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId) - - lazy protected val receivingThread = Thread.currentThread() - - /** This method will be called to start receiving data. */ - protected def onStart() - - /** This method will be called to stop receiving data. */ - protected def onStop() - - /** This method conveys a placement preference (hostname) for this receiver. */ - def getLocationPreference() : Option[String] = None - - /** - * This method starts the receiver. First is accesses all the lazy members to - * materialize them. Then it calls the user-defined onStart() method to start - * other threads, etc required to receiver the data. - */ - def start() { - try { - // Access the lazy vals to materialize them - env - actor - receivingThread - - // Call user-defined onStart() - onStart() - } catch { - case ie: InterruptedException => - logInfo("Receiving thread interrupted") - //println("Receiving thread interrupted") - case e: Exception => - stopOnError(e) - } - } - - /** - * This method stops the receiver. First it interrupts the main receiving thread, - * that is, the thread that called receiver.start(). Then it calls the user-defined - * onStop() method to stop other threads and/or do cleanup. - */ - def stop() { - receivingThread.interrupt() - onStop() - //TODO: terminate the actor - } - - /** - * This method stops the receiver and reports to exception to the tracker. - * This should be called whenever an exception has happened on any thread - * of the receiver. - */ - protected def stopOnError(e: Exception) { - logError("Error receiving data", e) - stop() - actor ! ReportError(e.toString) - } - - - /** - * This method pushes a block (as iterator of values) into the block manager. - */ - def pushBlock(blockId: String, iterator: Iterator[T], metadata: Any, level: StorageLevel) { - val buffer = new ArrayBuffer[T] ++ iterator - env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level) - - actor ! ReportBlock(blockId, metadata) - } - - /** - * This method pushes a block (as bytes) into the block manager. - */ - def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { - env.blockManager.putBytes(blockId, bytes, level) - actor ! ReportBlock(blockId, metadata) - } - - /** A helper actor that communicates with the NetworkInputTracker */ - private class NetworkReceiverActor extends Actor { - logInfo("Attempting to register with tracker") - val ip = System.getProperty("spark.master.host", "localhost") - val port = System.getProperty("spark.master.port", "7077").toInt - val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port) - val tracker = env.actorSystem.actorFor(url) - val timeout = 5.seconds - - override def preStart() { - val future = tracker.ask(RegisterReceiver(streamId, self))(timeout) - Await.result(future, timeout) - } - - override def receive() = { - case ReportBlock(blockId, metadata) => - tracker ! AddBlocks(streamId, Array(blockId), metadata) - case ReportError(msg) => - tracker ! DeregisterReceiver(streamId, msg) - case StopReceiver(msg) => - stop() - tracker ! DeregisterReceiver(streamId, msg) - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index 658498dfc1..a6ab44271f 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -1,5 +1,7 @@ package spark.streaming +import spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} +import spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} import spark.Logging import spark.SparkEnv diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index f9fef14196..b0a208e67f 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -1,6 +1,9 @@ package spark.streaming import spark.streaming.StreamingContext._ +import spark.streaming.dstream.{ReducedWindowedDStream, StateDStream} +import spark.streaming.dstream.{CoGroupedDStream, ShuffledDStream} +import spark.streaming.dstream.{MapValuedDStream, FlatMapValuedDStream} import spark.{Manifests, RDD, Partitioner, HashPartitioner} import spark.SparkContext._ @@ -218,13 +221,13 @@ extends Serializable { def mapValues[U: ClassManifest](mapValuesFunc: V => U): DStream[(K, U)] = { - new MapValuesDStream[K, V, U](self, mapValuesFunc) + new MapValuedDStream[K, V, U](self, mapValuesFunc) } def flatMapValues[U: ClassManifest]( flatMapValuesFunc: V => TraversableOnce[U] ): DStream[(K, U)] = { - new FlatMapValuesDStream[K, V, U](self, flatMapValuesFunc) + new FlatMapValuedDStream[K, V, U](self, flatMapValuesFunc) } def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = { diff --git a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala deleted file mode 100644 index bb86e51932..0000000000 --- a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala +++ /dev/null @@ -1,40 +0,0 @@ -package spark.streaming - -import spark.RDD -import spark.rdd.UnionRDD - -import scala.collection.mutable.Queue -import scala.collection.mutable.ArrayBuffer - -class QueueInputDStream[T: ClassManifest]( - @transient ssc: StreamingContext, - val queue: Queue[RDD[T]], - oneAtATime: Boolean, - defaultRDD: RDD[T] - ) extends InputDStream[T](ssc) { - - override def start() { } - - override def stop() { } - - override def compute(validTime: Time): Option[RDD[T]] = { - val buffer = new ArrayBuffer[RDD[T]]() - if (oneAtATime && queue.size > 0) { - buffer += queue.dequeue() - } else { - buffer ++= queue - } - if (buffer.size > 0) { - if (oneAtATime) { - Some(buffer.first) - } else { - Some(new UnionRDD(ssc.sc, buffer.toSeq)) - } - } else if (defaultRDD != null) { - Some(defaultRDD) - } else { - None - } - } - -} diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala deleted file mode 100644 index 6acaa9aab1..0000000000 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ /dev/null @@ -1,85 +0,0 @@ -package spark.streaming - -import java.net.InetSocketAddress -import java.nio.ByteBuffer -import java.nio.channels.{ReadableByteChannel, SocketChannel} -import java.io.EOFException -import java.util.concurrent.ArrayBlockingQueue -import spark._ -import spark.storage.StorageLevel - -/** - * An input stream that reads blocks of serialized objects from a given network address. - * The blocks will be inserted directly into the block store. This is the fastest way to get - * data into Spark Streaming, though it requires the sender to batch data and serialize it - * in the format that the system is configured with. - */ -class RawInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - storageLevel: StorageLevel - ) extends NetworkInputDStream[T](ssc_ ) with Logging { - - def createReceiver(): NetworkReceiver[T] = { - new RawNetworkReceiver(id, host, port, storageLevel).asInstanceOf[NetworkReceiver[T]] - } -} - -class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel) - extends NetworkReceiver[Any](streamId) { - - var blockPushingThread: Thread = null - - override def getLocationPreference = None - - def onStart() { - // Open a socket to the target address and keep reading from it - logInfo("Connecting to " + host + ":" + port) - val channel = SocketChannel.open() - channel.configureBlocking(true) - channel.connect(new InetSocketAddress(host, port)) - logInfo("Connected to " + host + ":" + port) - - val queue = new ArrayBlockingQueue[ByteBuffer](2) - - blockPushingThread = new DaemonThread { - override def run() { - var nextBlockNumber = 0 - while (true) { - val buffer = queue.take() - val blockId = "input-" + streamId + "-" + nextBlockNumber - nextBlockNumber += 1 - pushBlock(blockId, buffer, null, storageLevel) - } - } - } - blockPushingThread.start() - - val lengthBuffer = ByteBuffer.allocate(4) - while (true) { - lengthBuffer.clear() - readFully(channel, lengthBuffer) - lengthBuffer.flip() - val length = lengthBuffer.getInt() - val dataBuffer = ByteBuffer.allocate(length) - readFully(channel, dataBuffer) - dataBuffer.flip() - logInfo("Read a block with " + length + " bytes") - queue.put(dataBuffer) - } - } - - def onStop() { - if (blockPushingThread != null) blockPushingThread.interrupt() - } - - /** Read a buffer fully from a given Channel */ - private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) { - while (dest.position < dest.limit) { - if (channel.read(dest) == -1) { - throw new EOFException("End of channel") - } - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala deleted file mode 100644 index f63a9e0011..0000000000 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ /dev/null @@ -1,149 +0,0 @@ -package spark.streaming - -import spark.streaming.StreamingContext._ - -import spark.RDD -import spark.rdd.UnionRDD -import spark.rdd.CoGroupedRDD -import spark.Partitioner -import spark.SparkContext._ -import spark.storage.StorageLevel - -import scala.collection.mutable.ArrayBuffer -import collection.SeqProxy - -class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( - parent: DStream[(K, V)], - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - _windowTime: Time, - _slideTime: Time, - partitioner: Partitioner - ) extends DStream[(K,V)](parent.ssc) { - - assert(_windowTime.isMultipleOf(parent.slideTime), - "The window duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" - ) - - assert(_slideTime.isMultipleOf(parent.slideTime), - "The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" - ) - - // Reduce each batch of data using reduceByKey which will be further reduced by window - // by ReducedWindowedDStream - val reducedStream = parent.reduceByKey(reduceFunc, partitioner) - - // Persist RDDs to memory by default as these RDDs are going to be reused. - super.persist(StorageLevel.MEMORY_ONLY_SER) - reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) - - def windowTime: Time = _windowTime - - override def dependencies = List(reducedStream) - - override def slideTime: Time = _slideTime - - override val mustCheckpoint = true - - override def parentRememberDuration: Time = rememberDuration + windowTime - - override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { - super.persist(storageLevel) - reducedStream.persist(storageLevel) - this - } - - override def checkpoint(interval: Time): DStream[(K, V)] = { - super.checkpoint(interval) - //reducedStream.checkpoint(interval) - this - } - - override def compute(validTime: Time): Option[RDD[(K, V)]] = { - val reduceF = reduceFunc - val invReduceF = invReduceFunc - - val currentTime = validTime - val currentWindow = Interval(currentTime - windowTime + parent.slideTime, currentTime) - val previousWindow = currentWindow - slideTime - - logDebug("Window time = " + windowTime) - logDebug("Slide time = " + slideTime) - logDebug("ZeroTime = " + zeroTime) - logDebug("Current window = " + currentWindow) - logDebug("Previous window = " + previousWindow) - - // _____________________________ - // | previous window _________|___________________ - // |___________________| current window | --------------> Time - // |_____________________________| - // - // |________ _________| |________ _________| - // | | - // V V - // old RDDs new RDDs - // - - // Get the RDDs of the reduced values in "old time steps" - val oldRDDs = reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideTime) - logDebug("# old RDDs = " + oldRDDs.size) - - // Get the RDDs of the reduced values in "new time steps" - val newRDDs = reducedStream.slice(previousWindow.endTime + parent.slideTime, currentWindow.endTime) - logDebug("# new RDDs = " + newRDDs.size) - - // Get the RDD of the reduced value of the previous window - val previousWindowRDD = getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) - - // Make the list of RDDs that needs to cogrouped together for reducing their reduced values - val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs - - // Cogroup the reduced RDDs and merge the reduced values - val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner) - //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ - - val numOldValues = oldRDDs.size - val numNewValues = newRDDs.size - - val mergeValues = (seqOfValues: Seq[Seq[V]]) => { - if (seqOfValues.size != 1 + numOldValues + numNewValues) { - throw new Exception("Unexpected number of sequences of reduced values") - } - // Getting reduced values "old time steps" that will be removed from current window - val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) - // Getting reduced values "new time steps" - val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) - if (seqOfValues(0).isEmpty) { - // If previous window's reduce value does not exist, then at least new values should exist - if (newValues.isEmpty) { - throw new Exception("Neither previous window has value for key, nor new values found. " + - "Are you sure your key class hashes consistently?") - } - // Reduce the new values - newValues.reduce(reduceF) // return - } else { - // Get the previous window's reduced value - var tempValue = seqOfValues(0).head - // If old values exists, then inverse reduce then from previous value - if (!oldValues.isEmpty) { - tempValue = invReduceF(tempValue, oldValues.reduce(reduceF)) - } - // If new values exists, then reduce them with previous value - if (!newValues.isEmpty) { - tempValue = reduceF(tempValue, newValues.reduce(reduceF)) - } - tempValue // return - } - } - - val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) - - Some(mergedValuesRDD) - } - - -} - - diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index fd1fa77a24..aeb7c3eb0e 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -4,9 +4,6 @@ import util.{ManualClock, RecurringTimer, Clock} import spark.SparkEnv import spark.Logging -import scala.collection.mutable.HashMap - - private[streaming] class Scheduler(ssc: StreamingContext) extends Logging { diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala deleted file mode 100644 index a9e37c0ff0..0000000000 --- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala +++ /dev/null @@ -1,107 +0,0 @@ -package spark.streaming - -import spark.streaming.util.{RecurringTimer, SystemClock} -import spark.storage.StorageLevel - -import java.io._ -import java.net.Socket -import java.util.concurrent.ArrayBlockingQueue - -import scala.collection.mutable.ArrayBuffer -import scala.Serializable - -class SocketInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - bytesToObjects: InputStream => Iterator[T], - storageLevel: StorageLevel - ) extends NetworkInputDStream[T](ssc_) { - - def createReceiver(): NetworkReceiver[T] = { - new SocketReceiver(id, host, port, bytesToObjects, storageLevel) - } -} - - -class SocketReceiver[T: ClassManifest]( - streamId: Int, - host: String, - port: Int, - bytesToObjects: InputStream => Iterator[T], - storageLevel: StorageLevel - ) extends NetworkReceiver[T](streamId) { - - lazy protected val dataHandler = new DataHandler(this, storageLevel) - - override def getLocationPreference = None - - protected def onStart() { - logInfo("Connecting to " + host + ":" + port) - val socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) - dataHandler.start() - val iterator = bytesToObjects(socket.getInputStream()) - while(iterator.hasNext) { - val obj = iterator.next - dataHandler += obj - } - } - - protected def onStop() { - dataHandler.stop() - } - -} - - -object SocketReceiver { - - /** - * This methods translates the data from an inputstream (say, from a socket) - * to '\n' delimited strings and returns an iterator to access the strings. - */ - def bytesToLines(inputStream: InputStream): Iterator[String] = { - val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8")) - - val iterator = new Iterator[String] { - var gotNext = false - var finished = false - var nextValue: String = null - - private def getNext() { - try { - nextValue = dataInputStream.readLine() - if (nextValue == null) { - finished = true - } - } - gotNext = true - } - - override def hasNext: Boolean = { - if (!finished) { - if (!gotNext) { - getNext() - if (finished) { - dataInputStream.close() - } - } - } - !finished - } - - override def next(): String = { - if (finished) { - throw new NoSuchElementException("End of stream") - } - if (!gotNext) { - getNext() - } - gotNext = false - nextValue - } - } - iterator - } -} diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala deleted file mode 100644 index b7e4c1c30c..0000000000 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ /dev/null @@ -1,84 +0,0 @@ -package spark.streaming - -import spark.RDD -import spark.rdd.BlockRDD -import spark.Partitioner -import spark.rdd.MapPartitionsRDD -import spark.SparkContext._ -import spark.storage.StorageLevel - -class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( - parent: DStream[(K, V)], - updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], - partitioner: Partitioner, - preservePartitioning: Boolean - ) extends DStream[(K, S)](parent.ssc) { - - super.persist(StorageLevel.MEMORY_ONLY_SER) - - override def dependencies = List(parent) - - override def slideTime = parent.slideTime - - override val mustCheckpoint = true - - override def compute(validTime: Time): Option[RDD[(K, S)]] = { - - // Try to get the previous state RDD - getOrCompute(validTime - slideTime) match { - - case Some(prevStateRDD) => { // If previous state RDD exists - - // Try to get the parent RDD - parent.getOrCompute(validTime) match { - case Some(parentRDD) => { // If parent RDD exists, then compute as usual - - // Define the function for the mapPartition operation on cogrouped RDD; - // first map the cogrouped tuple to tuples of required type, - // and then apply the update function - val updateFuncLocal = updateFunc - val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { - val i = iterator.map(t => { - (t._1, t._2._1, t._2._2.headOption) - }) - updateFuncLocal(i) - } - val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) - val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning) - //logDebug("Generating state RDD for time " + validTime) - return Some(stateRDD) - } - case None => { // If parent RDD does not exist, then return old state RDD - return Some(prevStateRDD) - } - } - } - - case None => { // If previous session RDD does not exist (first input data) - - // Try to get the parent RDD - parent.getOrCompute(validTime) match { - case Some(parentRDD) => { // If parent RDD exists, then compute as usual - - // Define the function for the mapPartition operation on grouped RDD; - // first map the grouped tuple to tuples of required type, - // and then apply the update function - val updateFuncLocal = updateFunc - val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { - updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, None))) - } - - val groupedRDD = parentRDD.groupByKey(partitioner) - val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) - //logDebug("Generating state RDD for time " + validTime + " (first)") - return Some(sessionRDD) - } - case None => { // If parent RDD does not exist, then nothing to do! - //logDebug("Not generating state RDD (no previous state, no parent)") - return None - } - } - } - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 998fea849f..ef73049a81 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -1,10 +1,10 @@ package spark.streaming -import spark.RDD -import spark.Logging -import spark.SparkEnv -import spark.SparkContext +import spark.streaming.dstream._ + +import spark.{RDD, Logging, SparkEnv, SparkContext} import spark.storage.StorageLevel +import spark.util.MetadataCleaner import scala.collection.mutable.Queue @@ -18,7 +18,6 @@ import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.flume.source.avro.AvroFlumeEvent import org.apache.hadoop.fs.Path import java.util.UUID -import spark.util.MetadataCleaner /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -126,7 +125,7 @@ class StreamingContext private ( /** * Create an input stream that pulls messages form a Kafka Broker. * - * @param host Zookeper hostname. + * @param hostname Zookeper hostname. * @param port Zookeper port. * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed @@ -319,7 +318,7 @@ object StreamingContext { protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { if (prefix == null) { - time.millis.toString + time.milliseconds.toString } else if (suffix == null || suffix.length ==0) { prefix + "-" + time.milliseconds } else { diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 480d292d7c..2976e5e87b 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -1,6 +1,11 @@ package spark.streaming -case class Time(millis: Long) { +/** + * This class is simple wrapper class that represents time in UTC. + * @param millis Time in UTC long + */ + +case class Time(private val millis: Long) { def < (that: Time): Boolean = (this.millis < that.millis) @@ -15,7 +20,9 @@ case class Time(millis: Long) { def - (that: Time): Time = Time(millis - that.millis) def * (times: Int): Time = Time(millis * times) - + + def / (that: Time): Long = millis / that.millis + def floor(that: Time): Time = { val t = that.millis val m = math.floor(this.millis / t).toLong diff --git a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala deleted file mode 100644 index e4d2a634f5..0000000000 --- a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala +++ /dev/null @@ -1,39 +0,0 @@ -package spark.streaming - -import spark.RDD -import spark.rdd.UnionRDD -import spark.storage.StorageLevel - - -class WindowedDStream[T: ClassManifest]( - parent: DStream[T], - _windowTime: Time, - _slideTime: Time) - extends DStream[T](parent.ssc) { - - if (!_windowTime.isMultipleOf(parent.slideTime)) - throw new Exception("The window duration of WindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") - - if (!_slideTime.isMultipleOf(parent.slideTime)) - throw new Exception("The slide duration of WindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") - - parent.persist(StorageLevel.MEMORY_ONLY_SER) - - def windowTime: Time = _windowTime - - override def dependencies = List(parent) - - override def slideTime: Time = _slideTime - - override def parentRememberDuration: Time = rememberDuration + windowTime - - override def compute(validTime: Time): Option[RDD[T]] = { - val currentWindow = Interval(validTime - windowTime + parent.slideTime, validTime) - Some(new UnionRDD(ssc.sc, parent.slice(currentWindow))) - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala new file mode 100644 index 0000000000..2e427dadf7 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala @@ -0,0 +1,39 @@ +package spark.streaming.dstream + +import spark.{RDD, Partitioner} +import spark.rdd.CoGroupedRDD +import spark.streaming.{Time, DStream} + +class CoGroupedDStream[K : ClassManifest]( + parents: Seq[DStream[(_, _)]], + partitioner: Partitioner + ) extends DStream[(K, Seq[Seq[_]])](parents.head.ssc) { + + if (parents.length == 0) { + throw new IllegalArgumentException("Empty array of parents") + } + + if (parents.map(_.ssc).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different StreamingContexts") + } + + if (parents.map(_.slideTime).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different slide times") + } + + override def dependencies = parents.toList + + override def slideTime = parents.head.slideTime + + override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = { + val part = partitioner + val rdds = parents.flatMap(_.getOrCompute(validTime)) + if (rdds.size > 0) { + val q = new CoGroupedRDD[K](rdds, part) + Some(q) + } else { + None + } + } + +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala new file mode 100644 index 0000000000..41c3af4694 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala @@ -0,0 +1,19 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.streaming.{Time, StreamingContext} + +/** + * An input stream that always returns the same RDD on each timestep. Useful for testing. + */ +class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T]) + extends InputDStream[T](ssc_) { + + override def start() {} + + override def stop() {} + + override def compute(validTime: Time): Option[RDD[T]] = { + Some(rdd) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/dstream/DataHandler.scala b/streaming/src/main/scala/spark/streaming/dstream/DataHandler.scala new file mode 100644 index 0000000000..d737ba1ecc --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/DataHandler.scala @@ -0,0 +1,83 @@ +package spark.streaming.dstream + +import java.util.concurrent.ArrayBlockingQueue +import scala.collection.mutable.ArrayBuffer +import spark.Logging +import spark.streaming.util.{RecurringTimer, SystemClock} +import spark.storage.StorageLevel + + +/** + * This is a helper object that manages the data received from the socket. It divides + * the object received into small batches of 100s of milliseconds, pushes them as + * blocks into the block manager and reports the block IDs to the network input + * tracker. It starts two threads, one to periodically start a new batch and prepare + * the previous batch of as a block, the other to push the blocks into the block + * manager. + */ + class DataHandler[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel) + extends Serializable with Logging { + + case class Block(id: String, iterator: Iterator[T], metadata: Any = null) + + val clock = new SystemClock() + val blockInterval = 200L + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockStorageLevel = storageLevel + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def createBlock(blockId: String, iterator: Iterator[T]) : Block = { + new Block(blockId, iterator) + } + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + logInfo("Data handler stopped") + } + + def += (obj: T) { + currentBuffer += obj + } + + def updateCurrentBuffer(time: Long) { + try { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval) + val newBlock = createBlock(blockId, newBlockBuffer.toIterator) + blocksForPushing.add(newBlock) + } + } catch { + case ie: InterruptedException => + logInfo("Block interval timer thread interrupted") + case e: Exception => + receiver.stop() + } + } + + def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel) + } + } catch { + case ie: InterruptedException => + logInfo("Block pushing thread interrupted") + case e: Exception => + receiver.stop() + } + } + } \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala new file mode 100644 index 0000000000..8cdaff467b --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala @@ -0,0 +1,110 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.rdd.UnionRDD +import spark.streaming.{StreamingContext, Time} + +import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} + +import scala.collection.mutable.HashSet + + +class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( + @transient ssc_ : StreamingContext, + directory: String, + filter: PathFilter = FileInputDStream.defaultPathFilter, + newFilesOnly: Boolean = true) + extends InputDStream[(K, V)](ssc_) { + + @transient private var path_ : Path = null + @transient private var fs_ : FileSystem = null + + var lastModTime = 0L + val lastModTimeFiles = new HashSet[String]() + + def path(): Path = { + if (path_ == null) path_ = new Path(directory) + path_ + } + + def fs(): FileSystem = { + if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) + fs_ + } + + override def start() { + if (newFilesOnly) { + lastModTime = System.currentTimeMillis() + } else { + lastModTime = 0 + } + } + + override def stop() { } + + /** + * Finds the files that were modified since the last time this method was called and makes + * a union RDD out of them. Note that this maintains the list of files that were processed + * in the latest modification time in the previous call to this method. This is because the + * modification time returned by the FileStatus API seems to return times only at the + * granularity of seconds. Hence, new files may have the same modification time as the + * latest modification time in the previous call to this method and the list of files + * maintained is used to filter the one that have been processed. + */ + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + // Create the filter for selecting new files + val newFilter = new PathFilter() { + var latestModTime = 0L + val latestModTimeFiles = new HashSet[String]() + + def accept(path: Path): Boolean = { + if (!filter.accept(path)) { + return false + } else { + val modTime = fs.getFileStatus(path).getModificationTime() + if (modTime < lastModTime){ + return false + } else if (modTime == lastModTime && lastModTimeFiles.contains(path.toString)) { + return false + } + if (modTime > latestModTime) { + latestModTime = modTime + latestModTimeFiles.clear() + } + latestModTimeFiles += path.toString + return true + } + } + } + + val newFiles = fs.listStatus(path, newFilter) + logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) + if (newFiles.length > 0) { + // Update the modification time and the files processed for that modification time + if (lastModTime != newFilter.latestModTime) { + lastModTime = newFilter.latestModTime + lastModTimeFiles.clear() + } + lastModTimeFiles ++= newFilter.latestModTimeFiles + } + val newRDD = new UnionRDD(ssc.sc, newFiles.map( + file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) + Some(newRDD) + } +} + +object FileInputDStream { + val defaultPathFilter = new PathFilter with Serializable { + def accept(path: Path): Boolean = { + val file = path.getName() + if (file.startsWith(".") || file.endsWith("_tmp")) { + return false + } else { + return true + } + } + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala new file mode 100644 index 0000000000..1cbb4d536e --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala @@ -0,0 +1,21 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD + +private[streaming] +class FilteredDStream[T: ClassManifest]( + parent: DStream[T], + filterFunc: T => Boolean + ) extends DStream[T](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + parent.getOrCompute(validTime).map(_.filter(filterFunc)) + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala new file mode 100644 index 0000000000..11ed8cf317 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -0,0 +1,20 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD +import spark.SparkContext._ + +private[streaming] +class FlatMapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( + parent: DStream[(K, V)], + flatMapValueFunc: V => TraversableOnce[U] + ) extends DStream[(K, U)](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K, U)]] = { + parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc)) + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala new file mode 100644 index 0000000000..a13b4c9ff9 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala @@ -0,0 +1,20 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD + +private[streaming] +class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( + parent: DStream[T], + flatMapFunc: T => Traversable[U] + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala new file mode 100644 index 0000000000..7e988cadf4 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala @@ -0,0 +1,135 @@ +package spark.streaming.dstream + +import spark.streaming.StreamingContext + +import spark.Utils +import spark.storage.StorageLevel + +import org.apache.flume.source.avro.AvroSourceProtocol +import org.apache.flume.source.avro.AvroFlumeEvent +import org.apache.flume.source.avro.Status +import org.apache.avro.ipc.specific.SpecificResponder +import org.apache.avro.ipc.NettyServer + +import scala.collection.JavaConversions._ + +import java.net.InetSocketAddress +import java.io.{ObjectInput, ObjectOutput, Externalizable} +import java.nio.ByteBuffer + +class FlumeInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + storageLevel: StorageLevel +) extends NetworkInputDStream[SparkFlumeEvent](ssc_) { + + override def createReceiver(): NetworkReceiver[SparkFlumeEvent] = { + new FlumeReceiver(id, host, port, storageLevel) + } +} + +/** + * A wrapper class for AvroFlumeEvent's with a custom serialization format. + * + * This is necessary because AvroFlumeEvent uses inner data structures + * which are not serializable. + */ +class SparkFlumeEvent() extends Externalizable { + var event : AvroFlumeEvent = new AvroFlumeEvent() + + /* De-serialize from bytes. */ + def readExternal(in: ObjectInput) { + val bodyLength = in.readInt() + val bodyBuff = new Array[Byte](bodyLength) + in.read(bodyBuff) + + val numHeaders = in.readInt() + val headers = new java.util.HashMap[CharSequence, CharSequence] + + for (i <- 0 until numHeaders) { + val keyLength = in.readInt() + val keyBuff = new Array[Byte](keyLength) + in.read(keyBuff) + val key : String = Utils.deserialize(keyBuff) + + val valLength = in.readInt() + val valBuff = new Array[Byte](valLength) + in.read(valBuff) + val value : String = Utils.deserialize(valBuff) + + headers.put(key, value) + } + + event.setBody(ByteBuffer.wrap(bodyBuff)) + event.setHeaders(headers) + } + + /* Serialize to bytes. */ + def writeExternal(out: ObjectOutput) { + val body = event.getBody.array() + out.writeInt(body.length) + out.write(body) + + val numHeaders = event.getHeaders.size() + out.writeInt(numHeaders) + for ((k, v) <- event.getHeaders) { + val keyBuff = Utils.serialize(k.toString) + out.writeInt(keyBuff.length) + out.write(keyBuff) + val valBuff = Utils.serialize(v.toString) + out.writeInt(valBuff.length) + out.write(valBuff) + } + } +} + +private[streaming] object SparkFlumeEvent { + def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = { + val event = new SparkFlumeEvent + event.event = in + event + } +} + +/** A simple server that implements Flume's Avro protocol. */ +class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { + override def append(event : AvroFlumeEvent) : Status = { + receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event) + Status.OK + } + + override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { + events.foreach (event => + receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event)) + Status.OK + } +} + +/** A NetworkReceiver which listens for events using the + * Flume Avro interface.*/ +class FlumeReceiver( + streamId: Int, + host: String, + port: Int, + storageLevel: StorageLevel + ) extends NetworkReceiver[SparkFlumeEvent](streamId) { + + lazy val dataHandler = new DataHandler(this, storageLevel) + + protected override def onStart() { + val responder = new SpecificResponder( + classOf[AvroSourceProtocol], new FlumeEventServer(this)); + val server = new NettyServer(responder, new InetSocketAddress(host, port)); + dataHandler.start() + server.start() + logInfo("Flume receiver started") + } + + protected override def onStop() { + dataHandler.stop() + logInfo("Flume receiver stopped") + } + + override def getLocationPreference = Some(host) +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala new file mode 100644 index 0000000000..41c629a225 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala @@ -0,0 +1,28 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.streaming.{DStream, Job, Time} + +private[streaming] +class ForEachDStream[T: ClassManifest] ( + parent: DStream[T], + foreachFunc: (RDD[T], Time) => Unit + ) extends DStream[Unit](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Unit]] = None + + override def generateJob(time: Time): Option[Job] = { + parent.getOrCompute(time) match { + case Some(rdd) => + val jobFunc = () => { + foreachFunc(rdd, time) + } + Some(new Job(time, jobFunc)) + case None => None + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala new file mode 100644 index 0000000000..92ea503cae --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala @@ -0,0 +1,17 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD + +private[streaming] +class GlommedDStream[T: ClassManifest](parent: DStream[T]) + extends DStream[Array[T]](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Array[T]]] = { + parent.getOrCompute(validTime).map(_.glom()) + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala new file mode 100644 index 0000000000..4959c66b06 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala @@ -0,0 +1,19 @@ +package spark.streaming.dstream + +import spark.streaming.{StreamingContext, DStream} + +abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) + extends DStream[T](ssc_) { + + override def dependencies = List() + + override def slideTime = { + if (ssc == null) throw new Exception("ssc is null") + if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null") + ssc.graph.batchDuration + } + + def start() + + def stop() +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala new file mode 100644 index 0000000000..a46721af2f --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -0,0 +1,197 @@ +package spark.streaming.dstream + +import spark.Logging +import spark.storage.StorageLevel +import spark.streaming.{Time, DStreamCheckpointData, StreamingContext} + +import java.util.Properties +import java.util.concurrent.Executors + +import kafka.consumer._ +import kafka.message.{Message, MessageSet, MessageAndMetadata} +import kafka.serializer.StringDecoder +import kafka.utils.{Utils, ZKGroupTopicDirs} +import kafka.utils.ZkUtils._ + +import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ + + +// Key for a specific Kafka Partition: (broker, topic, group, part) +case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) +// NOT USED - Originally intended for fault-tolerance +// Metadata for a Kafka Stream that it sent to the Master +case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long]) +// NOT USED - Originally intended for fault-tolerance +// Checkpoint data specific to a KafkaInputDstream +case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], + savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds) + +/** + * Input stream that pulls messages from a Kafka Broker. + * + * @param host Zookeper hostname. + * @param port Zookeper port. + * @param groupId The group id for this consumer. + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param initialOffsets Optional initial offsets for each of the partitions to consume. + * By default the value is pulled from zookeper. + * @param storageLevel RDD storage level. + */ +class KafkaInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + groupId: String, + topics: Map[String, Int], + initialOffsets: Map[KafkaPartitionKey, Long], + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_ ) with Logging { + + // Metadata that keeps track of which messages have already been consumed. + var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]() + + /* NOT USED - Originally intended for fault-tolerance + + // In case of a failure, the offets for a particular timestamp will be restored. + @transient var restoredOffsets : Map[KafkaPartitionKey, Long] = null + + + override protected[streaming] def addMetadata(metadata: Any) { + metadata match { + case x : KafkaInputDStreamMetadata => + savedOffsets(x.timestamp) = x.data + // TOOD: Remove logging + logInfo("New saved Offsets: " + savedOffsets) + case _ => logInfo("Received unknown metadata: " + metadata.toString) + } + } + + override protected[streaming] def updateCheckpointData(currentTime: Time) { + super.updateCheckpointData(currentTime) + if(savedOffsets.size > 0) { + // Find the offets that were stored before the checkpoint was initiated + val key = savedOffsets.keys.toList.sortWith(_ < _).filter(_ < currentTime.millis).last + val latestOffsets = savedOffsets(key) + logInfo("Updating KafkaDStream checkpoint data: " + latestOffsets.toString) + checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, latestOffsets) + // TODO: This may throw out offsets that are created after the checkpoint, + // but it's unlikely we'll need them. + savedOffsets.clear() + } + } + + override protected[streaming] def restoreCheckpointData() { + super.restoreCheckpointData() + logInfo("Restoring KafkaDStream checkpoint data.") + checkpointData match { + case x : KafkaDStreamCheckpointData => + restoredOffsets = x.savedOffsets + logInfo("Restored KafkaDStream offsets: " + savedOffsets) + } + } */ + + def createReceiver(): NetworkReceiver[T] = { + new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel) + .asInstanceOf[NetworkReceiver[T]] + } +} + +class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, + topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], + storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) { + + // Timeout for establishing a connection to Zookeper in ms. + val ZK_TIMEOUT = 10000 + + // Handles pushing data into the BlockManager + lazy protected val dataHandler = new DataHandler(this, storageLevel) + // Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset + lazy val offsets = HashMap[KafkaPartitionKey, Long]() + // Connection to Kafka + var consumerConnector : ZookeeperConsumerConnector = null + + def onStop() { + dataHandler.stop() + } + + def onStart() { + + // Starting the DataHandler that buffers blocks and pushes them into them BlockManager + dataHandler.start() + + // In case we are using multiple Threads to handle Kafka Messages + val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) + + val zooKeeperEndPoint = host + ":" + port + logInfo("Starting Kafka Consumer Stream with group: " + groupId) + logInfo("Initial offsets: " + initialOffsets.toString) + + // Zookeper connection properties + val props = new Properties() + props.put("zk.connect", zooKeeperEndPoint) + props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString) + props.put("groupid", groupId) + + // Create the connection to the cluster + logInfo("Connecting to Zookeper: " + zooKeeperEndPoint) + val consumerConfig = new ConsumerConfig(props) + consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] + logInfo("Connected to " + zooKeeperEndPoint) + + // Reset the Kafka offsets in case we are recovering from a failure + resetOffsets(initialOffsets) + + // Create Threads for each Topic/Message Stream we are listening + val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) + + // Start the messages handler for each partition + topicMessageStreams.values.foreach { streams => + streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } + } + + } + + // Overwrites the offets in Zookeper. + private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) { + offsets.foreach { case(key, offset) => + val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic) + val partitionName = key.brokerId + "-" + key.partId + updatePersistentPath(consumerConnector.zkClient, + topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString) + } + } + + // Handles Kafka Messages + private class MessageHandler(stream: KafkaStream[String]) extends Runnable { + def run() { + logInfo("Starting MessageHandler.") + stream.takeWhile { msgAndMetadata => + dataHandler += msgAndMetadata.message + + // Updating the offet. The key is (broker, topic, group, partition). + val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic, + groupId, msgAndMetadata.topicInfo.partition.partId) + val offset = msgAndMetadata.topicInfo.getConsumeOffset + offsets.put(key, offset) + // logInfo("Handled message: " + (key, offset).toString) + + // Keep on handling messages + true + } + } + } + + // NOT USED - Originally intended for fault-tolerance + // class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel) + // extends DataHandler[Any](receiver, storageLevel) { + + // override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = { + // // Creates a new Block with Kafka-specific Metadata + // new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap)) + // } + + // } + +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala new file mode 100644 index 0000000000..daf78c6893 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala @@ -0,0 +1,21 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD + +private[streaming] +class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( + parent: DStream[T], + mapPartFunc: Iterator[T] => Iterator[U], + preservePartitioning: Boolean + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc, preservePartitioning)) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala new file mode 100644 index 0000000000..689caeef0e --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala @@ -0,0 +1,21 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD +import spark.SparkContext._ + +private[streaming] +class MapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( + parent: DStream[(K, V)], + mapValueFunc: V => U + ) extends DStream[(K, U)](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K, U)]] = { + parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc)) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala new file mode 100644 index 0000000000..786b9966f2 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala @@ -0,0 +1,20 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD + +private[streaming] +class MappedDStream[T: ClassManifest, U: ClassManifest] ( + parent: DStream[T], + mapFunc: T => U + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.map[U](mapFunc)) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala new file mode 100644 index 0000000000..41276da8bb --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala @@ -0,0 +1,157 @@ +package spark.streaming.dstream + +import spark.streaming.{Time, StreamingContext, AddBlocks, RegisterReceiver, DeregisterReceiver} + +import spark.{Logging, SparkEnv, RDD} +import spark.rdd.BlockRDD +import spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer + +import java.nio.ByteBuffer + +import akka.actor.{Props, Actor} +import akka.pattern.ask +import akka.dispatch.Await +import akka.util.duration._ + +abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext) + extends InputDStream[T](ssc_) { + + // This is an unique identifier that is used to match the network receiver with the + // corresponding network input stream. + val id = ssc.getNewNetworkStreamId() + + /** + * This method creates the receiver object that will be sent to the workers + * to receive data. This method needs to defined by any specific implementation + * of a NetworkInputDStream. + */ + def createReceiver(): NetworkReceiver[T] + + // Nothing to start or stop as both taken care of by the NetworkInputTracker. + def start() {} + + def stop() {} + + override def compute(validTime: Time): Option[RDD[T]] = { + val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) + Some(new BlockRDD[T](ssc.sc, blockIds)) + } +} + + +private[streaming] sealed trait NetworkReceiverMessage +private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage +private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage +private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage + +abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging { + + initLogging() + + lazy protected val env = SparkEnv.get + + lazy protected val actor = env.actorSystem.actorOf( + Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId) + + lazy protected val receivingThread = Thread.currentThread() + + /** This method will be called to start receiving data. */ + protected def onStart() + + /** This method will be called to stop receiving data. */ + protected def onStop() + + /** This method conveys a placement preference (hostname) for this receiver. */ + def getLocationPreference() : Option[String] = None + + /** + * This method starts the receiver. First is accesses all the lazy members to + * materialize them. Then it calls the user-defined onStart() method to start + * other threads, etc required to receiver the data. + */ + def start() { + try { + // Access the lazy vals to materialize them + env + actor + receivingThread + + // Call user-defined onStart() + onStart() + } catch { + case ie: InterruptedException => + logInfo("Receiving thread interrupted") + //println("Receiving thread interrupted") + case e: Exception => + stopOnError(e) + } + } + + /** + * This method stops the receiver. First it interrupts the main receiving thread, + * that is, the thread that called receiver.start(). Then it calls the user-defined + * onStop() method to stop other threads and/or do cleanup. + */ + def stop() { + receivingThread.interrupt() + onStop() + //TODO: terminate the actor + } + + /** + * This method stops the receiver and reports to exception to the tracker. + * This should be called whenever an exception has happened on any thread + * of the receiver. + */ + protected def stopOnError(e: Exception) { + logError("Error receiving data", e) + stop() + actor ! ReportError(e.toString) + } + + + /** + * This method pushes a block (as iterator of values) into the block manager. + */ + def pushBlock(blockId: String, iterator: Iterator[T], metadata: Any, level: StorageLevel) { + val buffer = new ArrayBuffer[T] ++ iterator + env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level) + + actor ! ReportBlock(blockId, metadata) + } + + /** + * This method pushes a block (as bytes) into the block manager. + */ + def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { + env.blockManager.putBytes(blockId, bytes, level) + actor ! ReportBlock(blockId, metadata) + } + + /** A helper actor that communicates with the NetworkInputTracker */ + private class NetworkReceiverActor extends Actor { + logInfo("Attempting to register with tracker") + val ip = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port) + val tracker = env.actorSystem.actorFor(url) + val timeout = 5.seconds + + override def preStart() { + val future = tracker.ask(RegisterReceiver(streamId, self))(timeout) + Await.result(future, timeout) + } + + override def receive() = { + case ReportBlock(blockId, metadata) => + tracker ! AddBlocks(streamId, Array(blockId), metadata) + case ReportError(msg) => + tracker ! DeregisterReceiver(streamId, msg) + case StopReceiver(msg) => + stop() + tracker ! DeregisterReceiver(streamId, msg) + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala new file mode 100644 index 0000000000..024bf3bea4 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala @@ -0,0 +1,41 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.rdd.UnionRDD + +import scala.collection.mutable.Queue +import scala.collection.mutable.ArrayBuffer +import spark.streaming.{Time, StreamingContext} + +class QueueInputDStream[T: ClassManifest]( + @transient ssc: StreamingContext, + val queue: Queue[RDD[T]], + oneAtATime: Boolean, + defaultRDD: RDD[T] + ) extends InputDStream[T](ssc) { + + override def start() { } + + override def stop() { } + + override def compute(validTime: Time): Option[RDD[T]] = { + val buffer = new ArrayBuffer[RDD[T]]() + if (oneAtATime && queue.size > 0) { + buffer += queue.dequeue() + } else { + buffer ++= queue + } + if (buffer.size > 0) { + if (oneAtATime) { + Some(buffer.first) + } else { + Some(new UnionRDD(ssc.sc, buffer.toSeq)) + } + } else if (defaultRDD != null) { + Some(defaultRDD) + } else { + None + } + } + +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala new file mode 100644 index 0000000000..996cc7dea8 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala @@ -0,0 +1,88 @@ +package spark.streaming.dstream + +import spark.{DaemonThread, Logging} +import spark.storage.StorageLevel +import spark.streaming.StreamingContext + +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.{ReadableByteChannel, SocketChannel} +import java.io.EOFException +import java.util.concurrent.ArrayBlockingQueue + + +/** + * An input stream that reads blocks of serialized objects from a given network address. + * The blocks will be inserted directly into the block store. This is the fastest way to get + * data into Spark Streaming, though it requires the sender to batch data and serialize it + * in the format that the system is configured with. + */ +class RawInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_ ) with Logging { + + def createReceiver(): NetworkReceiver[T] = { + new RawNetworkReceiver(id, host, port, storageLevel).asInstanceOf[NetworkReceiver[T]] + } +} + +class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel) + extends NetworkReceiver[Any](streamId) { + + var blockPushingThread: Thread = null + + override def getLocationPreference = None + + def onStart() { + // Open a socket to the target address and keep reading from it + logInfo("Connecting to " + host + ":" + port) + val channel = SocketChannel.open() + channel.configureBlocking(true) + channel.connect(new InetSocketAddress(host, port)) + logInfo("Connected to " + host + ":" + port) + + val queue = new ArrayBlockingQueue[ByteBuffer](2) + + blockPushingThread = new DaemonThread { + override def run() { + var nextBlockNumber = 0 + while (true) { + val buffer = queue.take() + val blockId = "input-" + streamId + "-" + nextBlockNumber + nextBlockNumber += 1 + pushBlock(blockId, buffer, null, storageLevel) + } + } + } + blockPushingThread.start() + + val lengthBuffer = ByteBuffer.allocate(4) + while (true) { + lengthBuffer.clear() + readFully(channel, lengthBuffer) + lengthBuffer.flip() + val length = lengthBuffer.getInt() + val dataBuffer = ByteBuffer.allocate(length) + readFully(channel, dataBuffer) + dataBuffer.flip() + logInfo("Read a block with " + length + " bytes") + queue.put(dataBuffer) + } + } + + def onStop() { + if (blockPushingThread != null) blockPushingThread.interrupt() + } + + /** Read a buffer fully from a given Channel */ + private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) { + while (dest.position < dest.limit) { + if (channel.read(dest) == -1) { + throw new EOFException("End of channel") + } + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala new file mode 100644 index 0000000000..2686de14d2 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -0,0 +1,148 @@ +package spark.streaming.dstream + +import spark.streaming.StreamingContext._ + +import spark.RDD +import spark.rdd.CoGroupedRDD +import spark.Partitioner +import spark.SparkContext._ +import spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer +import spark.streaming.{Interval, Time, DStream} + +class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( + parent: DStream[(K, V)], + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + _windowTime: Time, + _slideTime: Time, + partitioner: Partitioner + ) extends DStream[(K,V)](parent.ssc) { + + assert(_windowTime.isMultipleOf(parent.slideTime), + "The window duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + ) + + assert(_slideTime.isMultipleOf(parent.slideTime), + "The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + ) + + // Reduce each batch of data using reduceByKey which will be further reduced by window + // by ReducedWindowedDStream + val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + + // Persist RDDs to memory by default as these RDDs are going to be reused. + super.persist(StorageLevel.MEMORY_ONLY_SER) + reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) + + def windowTime: Time = _windowTime + + override def dependencies = List(reducedStream) + + override def slideTime: Time = _slideTime + + override val mustCheckpoint = true + + override def parentRememberDuration: Time = rememberDuration + windowTime + + override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { + super.persist(storageLevel) + reducedStream.persist(storageLevel) + this + } + + override def checkpoint(interval: Time): DStream[(K, V)] = { + super.checkpoint(interval) + //reducedStream.checkpoint(interval) + this + } + + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + val reduceF = reduceFunc + val invReduceF = invReduceFunc + + val currentTime = validTime + val currentWindow = Interval(currentTime - windowTime + parent.slideTime, currentTime) + val previousWindow = currentWindow - slideTime + + logDebug("Window time = " + windowTime) + logDebug("Slide time = " + slideTime) + logDebug("ZeroTime = " + zeroTime) + logDebug("Current window = " + currentWindow) + logDebug("Previous window = " + previousWindow) + + // _____________________________ + // | previous window _________|___________________ + // |___________________| current window | --------------> Time + // |_____________________________| + // + // |________ _________| |________ _________| + // | | + // V V + // old RDDs new RDDs + // + + // Get the RDDs of the reduced values in "old time steps" + val oldRDDs = reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideTime) + logDebug("# old RDDs = " + oldRDDs.size) + + // Get the RDDs of the reduced values in "new time steps" + val newRDDs = reducedStream.slice(previousWindow.endTime + parent.slideTime, currentWindow.endTime) + logDebug("# new RDDs = " + newRDDs.size) + + // Get the RDD of the reduced value of the previous window + val previousWindowRDD = getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) + + // Make the list of RDDs that needs to cogrouped together for reducing their reduced values + val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs + + // Cogroup the reduced RDDs and merge the reduced values + val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner) + //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ + + val numOldValues = oldRDDs.size + val numNewValues = newRDDs.size + + val mergeValues = (seqOfValues: Seq[Seq[V]]) => { + if (seqOfValues.size != 1 + numOldValues + numNewValues) { + throw new Exception("Unexpected number of sequences of reduced values") + } + // Getting reduced values "old time steps" that will be removed from current window + val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) + // Getting reduced values "new time steps" + val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) + if (seqOfValues(0).isEmpty) { + // If previous window's reduce value does not exist, then at least new values should exist + if (newValues.isEmpty) { + throw new Exception("Neither previous window has value for key, nor new values found. " + + "Are you sure your key class hashes consistently?") + } + // Reduce the new values + newValues.reduce(reduceF) // return + } else { + // Get the previous window's reduced value + var tempValue = seqOfValues(0).head + // If old values exists, then inverse reduce then from previous value + if (!oldValues.isEmpty) { + tempValue = invReduceF(tempValue, oldValues.reduce(reduceF)) + } + // If new values exists, then reduce them with previous value + if (!newValues.isEmpty) { + tempValue = reduceF(tempValue, newValues.reduce(reduceF)) + } + tempValue // return + } + } + + val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) + + Some(mergedValuesRDD) + } + + +} + + diff --git a/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala new file mode 100644 index 0000000000..6854bbe665 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala @@ -0,0 +1,27 @@ +package spark.streaming.dstream + +import spark.{RDD, Partitioner} +import spark.SparkContext._ +import spark.streaming.{DStream, Time} + +private[streaming] +class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( + parent: DStream[(K,V)], + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + partitioner: Partitioner + ) extends DStream [(K,C)] (parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K,C)]] = { + parent.getOrCompute(validTime) match { + case Some(rdd) => + Some(rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner)) + case None => None + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala new file mode 100644 index 0000000000..af5b73ae8d --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala @@ -0,0 +1,103 @@ +package spark.streaming.dstream + +import spark.streaming.StreamingContext +import spark.storage.StorageLevel + +import java.io._ +import java.net.Socket + +class SocketInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T], + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_) { + + def createReceiver(): NetworkReceiver[T] = { + new SocketReceiver(id, host, port, bytesToObjects, storageLevel) + } +} + + +class SocketReceiver[T: ClassManifest]( + streamId: Int, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T], + storageLevel: StorageLevel + ) extends NetworkReceiver[T](streamId) { + + lazy protected val dataHandler = new DataHandler(this, storageLevel) + + override def getLocationPreference = None + + protected def onStart() { + logInfo("Connecting to " + host + ":" + port) + val socket = new Socket(host, port) + logInfo("Connected to " + host + ":" + port) + dataHandler.start() + val iterator = bytesToObjects(socket.getInputStream()) + while(iterator.hasNext) { + val obj = iterator.next + dataHandler += obj + } + } + + protected def onStop() { + dataHandler.stop() + } + +} + + +object SocketReceiver { + + /** + * This methods translates the data from an inputstream (say, from a socket) + * to '\n' delimited strings and returns an iterator to access the strings. + */ + def bytesToLines(inputStream: InputStream): Iterator[String] = { + val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8")) + + val iterator = new Iterator[String] { + var gotNext = false + var finished = false + var nextValue: String = null + + private def getNext() { + try { + nextValue = dataInputStream.readLine() + if (nextValue == null) { + finished = true + } + } + gotNext = true + } + + override def hasNext: Boolean = { + if (!finished) { + if (!gotNext) { + getNext() + if (finished) { + dataInputStream.close() + } + } + } + !finished + } + + override def next(): String = { + if (finished) { + throw new NoSuchElementException("End of stream") + } + if (!gotNext) { + getNext() + } + gotNext = false + nextValue + } + } + iterator + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala new file mode 100644 index 0000000000..6e190b5564 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala @@ -0,0 +1,83 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.Partitioner +import spark.SparkContext._ +import spark.storage.StorageLevel +import spark.streaming.{Time, DStream} + +class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( + parent: DStream[(K, V)], + updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], + partitioner: Partitioner, + preservePartitioning: Boolean + ) extends DStream[(K, S)](parent.ssc) { + + super.persist(StorageLevel.MEMORY_ONLY_SER) + + override def dependencies = List(parent) + + override def slideTime = parent.slideTime + + override val mustCheckpoint = true + + override def compute(validTime: Time): Option[RDD[(K, S)]] = { + + // Try to get the previous state RDD + getOrCompute(validTime - slideTime) match { + + case Some(prevStateRDD) => { // If previous state RDD exists + + // Try to get the parent RDD + parent.getOrCompute(validTime) match { + case Some(parentRDD) => { // If parent RDD exists, then compute as usual + + // Define the function for the mapPartition operation on cogrouped RDD; + // first map the cogrouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { + val i = iterator.map(t => { + (t._1, t._2._1, t._2._2.headOption) + }) + updateFuncLocal(i) + } + val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) + val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning) + //logDebug("Generating state RDD for time " + validTime) + return Some(stateRDD) + } + case None => { // If parent RDD does not exist, then return old state RDD + return Some(prevStateRDD) + } + } + } + + case None => { // If previous session RDD does not exist (first input data) + + // Try to get the parent RDD + parent.getOrCompute(validTime) match { + case Some(parentRDD) => { // If parent RDD exists, then compute as usual + + // Define the function for the mapPartition operation on grouped RDD; + // first map the grouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { + updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, None))) + } + + val groupedRDD = parentRDD.groupByKey(partitioner) + val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) + //logDebug("Generating state RDD for time " + validTime + " (first)") + return Some(sessionRDD) + } + case None => { // If parent RDD does not exist, then nothing to do! + //logDebug("Not generating state RDD (no previous state, no parent)") + return None + } + } + } + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala new file mode 100644 index 0000000000..0337579514 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala @@ -0,0 +1,19 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.streaming.{DStream, Time} + +private[streaming] +class TransformedDStream[T: ClassManifest, U: ClassManifest] ( + parent: DStream[T], + transformFunc: (RDD[T], Time) => RDD[U] + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(transformFunc(_, validTime)) + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala new file mode 100644 index 0000000000..f1efb2ae72 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala @@ -0,0 +1,39 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD +import collection.mutable.ArrayBuffer +import spark.rdd.UnionRDD + +class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) + extends DStream[T](parents.head.ssc) { + + if (parents.length == 0) { + throw new IllegalArgumentException("Empty array of parents") + } + + if (parents.map(_.ssc).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different StreamingContexts") + } + + if (parents.map(_.slideTime).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different slide times") + } + + override def dependencies = parents.toList + + override def slideTime: Time = parents.head.slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + val rdds = new ArrayBuffer[RDD[T]]() + parents.map(_.getOrCompute(validTime)).foreach(_ match { + case Some(rdd) => rdds += rdd + case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) + }) + if (rdds.size > 0) { + Some(new UnionRDD(ssc.sc, rdds)) + } else { + None + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala new file mode 100644 index 0000000000..4b2621c497 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala @@ -0,0 +1,40 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.rdd.UnionRDD +import spark.storage.StorageLevel +import spark.streaming.{Interval, Time, DStream} + + +class WindowedDStream[T: ClassManifest]( + parent: DStream[T], + _windowTime: Time, + _slideTime: Time) + extends DStream[T](parent.ssc) { + + if (!_windowTime.isMultipleOf(parent.slideTime)) + throw new Exception("The window duration of WindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + + if (!_slideTime.isMultipleOf(parent.slideTime)) + throw new Exception("The slide duration of WindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + + parent.persist(StorageLevel.MEMORY_ONLY_SER) + + def windowTime: Time = _windowTime + + override def dependencies = List(parent) + + override def slideTime: Time = _slideTime + + override def parentRememberDuration: Time = rememberDuration + windowTime + + override def compute(validTime: Time): Option[RDD[T]] = { + val currentWindow = Interval(validTime - windowTime + parent.slideTime, validTime) + Some(new UnionRDD(ssc.sc, parent.slice(currentWindow))) + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala index 7c4ee3b34c..dfaaf03f03 100644 --- a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala @@ -25,7 +25,7 @@ object GrepRaw { val rawStreams = (1 to numStreams).map(_ => ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray - val union = new UnionDStream(rawStreams) + val union = ssc.union(rawStreams) union.filter(_.contains("Alice")).count().foreach(r => println("Grep count: " + r.collect().mkString)) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index 182dfd8a52..338834bc3c 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -34,7 +34,7 @@ object TopKWordCountRaw { val lines = (1 to numStreams).map(_ => { ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) }) - val union = new UnionDStream(lines.toArray) + val union = ssc.union(lines) val counts = union.mapPartitions(splitAndCountPartitions) val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index 9bcd30f4d7..d93335a8ce 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -33,7 +33,7 @@ object WordCountRaw { val lines = (1 to numStreams).map(_ => { ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) }) - val union = new UnionDStream(lines.toArray) + val union = ssc.union(lines) val counts = union.mapPartitions(splitAndCountPartitions) val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) windowedCounts.foreach(r => println("# unique words = " + r.count())) diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala deleted file mode 100644 index 7c642d4802..0000000000 --- a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala +++ /dev/null @@ -1,193 +0,0 @@ -package spark.streaming - -import java.util.Properties -import java.util.concurrent.Executors -import kafka.consumer._ -import kafka.message.{Message, MessageSet, MessageAndMetadata} -import kafka.serializer.StringDecoder -import kafka.utils.{Utils, ZKGroupTopicDirs} -import kafka.utils.ZkUtils._ -import scala.collection.mutable.HashMap -import scala.collection.JavaConversions._ -import spark._ -import spark.RDD -import spark.storage.StorageLevel - -// Key for a specific Kafka Partition: (broker, topic, group, part) -case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) -// NOT USED - Originally intended for fault-tolerance -// Metadata for a Kafka Stream that it sent to the Master -case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long]) -// NOT USED - Originally intended for fault-tolerance -// Checkpoint data specific to a KafkaInputDstream -case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], - savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds) - -/** - * Input stream that pulls messages form a Kafka Broker. - * - * @param host Zookeper hostname. - * @param port Zookeper port. - * @param groupId The group id for this consumer. - * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. - * By default the value is pulled from zookeper. - * @param storageLevel RDD storage level. - */ -class KafkaInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - groupId: String, - topics: Map[String, Int], - initialOffsets: Map[KafkaPartitionKey, Long], - storageLevel: StorageLevel - ) extends NetworkInputDStream[T](ssc_ ) with Logging { - - // Metadata that keeps track of which messages have already been consumed. - var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]() - - /* NOT USED - Originally intended for fault-tolerance - - // In case of a failure, the offets for a particular timestamp will be restored. - @transient var restoredOffsets : Map[KafkaPartitionKey, Long] = null - - - override protected[streaming] def addMetadata(metadata: Any) { - metadata match { - case x : KafkaInputDStreamMetadata => - savedOffsets(x.timestamp) = x.data - // TOOD: Remove logging - logInfo("New saved Offsets: " + savedOffsets) - case _ => logInfo("Received unknown metadata: " + metadata.toString) - } - } - - override protected[streaming] def updateCheckpointData(currentTime: Time) { - super.updateCheckpointData(currentTime) - if(savedOffsets.size > 0) { - // Find the offets that were stored before the checkpoint was initiated - val key = savedOffsets.keys.toList.sortWith(_ < _).filter(_ < currentTime.millis).last - val latestOffsets = savedOffsets(key) - logInfo("Updating KafkaDStream checkpoint data: " + latestOffsets.toString) - checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, latestOffsets) - // TODO: This may throw out offsets that are created after the checkpoint, - // but it's unlikely we'll need them. - savedOffsets.clear() - } - } - - override protected[streaming] def restoreCheckpointData() { - super.restoreCheckpointData() - logInfo("Restoring KafkaDStream checkpoint data.") - checkpointData match { - case x : KafkaDStreamCheckpointData => - restoredOffsets = x.savedOffsets - logInfo("Restored KafkaDStream offsets: " + savedOffsets) - } - } */ - - def createReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel) - .asInstanceOf[NetworkReceiver[T]] - } -} - -class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, - topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], - storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) { - - // Timeout for establishing a connection to Zookeper in ms. - val ZK_TIMEOUT = 10000 - - // Handles pushing data into the BlockManager - lazy protected val dataHandler = new DataHandler(this, storageLevel) - // Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset - lazy val offsets = HashMap[KafkaPartitionKey, Long]() - // Connection to Kafka - var consumerConnector : ZookeeperConsumerConnector = null - - def onStop() { - dataHandler.stop() - } - - def onStart() { - - // Starting the DataHandler that buffers blocks and pushes them into them BlockManager - dataHandler.start() - - // In case we are using multiple Threads to handle Kafka Messages - val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) - - val zooKeeperEndPoint = host + ":" + port - logInfo("Starting Kafka Consumer Stream with group: " + groupId) - logInfo("Initial offsets: " + initialOffsets.toString) - - // Zookeper connection properties - val props = new Properties() - props.put("zk.connect", zooKeeperEndPoint) - props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString) - props.put("groupid", groupId) - - // Create the connection to the cluster - logInfo("Connecting to Zookeper: " + zooKeeperEndPoint) - val consumerConfig = new ConsumerConfig(props) - consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] - logInfo("Connected to " + zooKeeperEndPoint) - - // Reset the Kafka offsets in case we are recovering from a failure - resetOffsets(initialOffsets) - - // Create Threads for each Topic/Message Stream we are listening - val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) - - // Start the messages handler for each partition - topicMessageStreams.values.foreach { streams => - streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } - } - - } - - // Overwrites the offets in Zookeper. - private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) { - offsets.foreach { case(key, offset) => - val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic) - val partitionName = key.brokerId + "-" + key.partId - updatePersistentPath(consumerConnector.zkClient, - topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString) - } - } - - // Handles Kafka Messages - private class MessageHandler(stream: KafkaStream[String]) extends Runnable { - def run() { - logInfo("Starting MessageHandler.") - stream.takeWhile { msgAndMetadata => - dataHandler += msgAndMetadata.message - - // Updating the offet. The key is (broker, topic, group, partition). - val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic, - groupId, msgAndMetadata.topicInfo.partition.partId) - val offset = msgAndMetadata.topicInfo.getConsumeOffset - offsets.put(key, offset) - // logInfo("Handled message: " + (key, offset).toString) - - // Keep on handling messages - true - } - } - } - - // NOT USED - Originally intended for fault-tolerance - // class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel) - // extends DataHandler[Any](receiver, storageLevel) { - - // override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = { - // // Creates a new Block with Kafka-specific Metadata - // new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap)) - // } - - // } - -} diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 0d82b2f1ea..920388bba9 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -42,7 +42,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val stateStreamCheckpointInterval = Seconds(1) // this ensure checkpointing occurs at least once - val firstNumBatches = (stateStreamCheckpointInterval.millis / batchDuration.millis) * 2 + val firstNumBatches = (stateStreamCheckpointInterval / batchDuration) * 2 val secondNumBatches = firstNumBatches // Setup the streams diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala index 5b414117fc..4aa428bf64 100644 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -133,7 +133,7 @@ class FailureSuite extends TestSuiteBase with BeforeAndAfter { // Get the output buffer val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] val output = outputStream.output - val waitTime = (batchDuration.millis * (numBatches.toDouble + 0.5)).toLong + val waitTime = (batchDuration.milliseconds * (numBatches.toDouble + 0.5)).toLong val startTime = System.currentTimeMillis() try { diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index ed9a659092..76b528bec3 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -1,5 +1,6 @@ package spark.streaming +import dstream.SparkFlumeEvent import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket} import java.io.{File, BufferedWriter, OutputStreamWriter} import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index a44f738957..28bdd53c3c 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -1,12 +1,16 @@ package spark.streaming +import spark.streaming.dstream.{InputDStream, ForEachDStream} +import spark.streaming.util.ManualClock + import spark.{RDD, Logging} -import util.ManualClock + import collection.mutable.ArrayBuffer -import org.scalatest.FunSuite import collection.mutable.SynchronizedBuffer + import java.io.{ObjectInputStream, IOException} +import org.scalatest.FunSuite /** * This is a input stream just for the testsuites. This is equivalent to a checkpointable, @@ -70,6 +74,10 @@ trait TestSuiteBase extends FunSuite with Logging { def actuallyWait = false + /** + * Set up required DStreams to test the DStream operation using the two sequences + * of input collections. + */ def setupStreams[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V] @@ -90,6 +98,10 @@ trait TestSuiteBase extends FunSuite with Logging { ssc } + /** + * Set up required DStreams to test the binary operation using the sequence + * of input collections. + */ def setupStreams[U: ClassManifest, V: ClassManifest, W: ClassManifest]( input1: Seq[Seq[U]], input2: Seq[Seq[V]], @@ -173,6 +185,11 @@ trait TestSuiteBase extends FunSuite with Logging { output } + /** + * Verify whether the output values after running a DStream operation + * is same as the expected output values, by comparing the output + * collections either as lists (order matters) or sets (order does not matter) + */ def verifyOutput[V: ClassManifest]( output: Seq[Seq[V]], expectedOutput: Seq[Seq[V]], @@ -199,6 +216,10 @@ trait TestSuiteBase extends FunSuite with Logging { logInfo("Output verified successfully") } + /** + * Test unary DStream operation with a list of inputs, with number of + * batches to run same as the number of expected output values + */ def testOperation[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V], @@ -208,6 +229,15 @@ trait TestSuiteBase extends FunSuite with Logging { testOperation[U, V](input, operation, expectedOutput, -1, useSet) } + /** + * Test unary DStream operation with a list of inputs + * @param input Sequence of input collections + * @param operation Binary DStream operation to be applied to the 2 inputs + * @param expectedOutput Sequence of expected output collections + * @param numBatches Number of batches to run the operation for + * @param useSet Compare the output values with the expected output values + * as sets (order matters) or as lists (order does not matter) + */ def testOperation[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V], @@ -221,6 +251,10 @@ trait TestSuiteBase extends FunSuite with Logging { verifyOutput[V](output, expectedOutput, useSet) } + /** + * Test binary DStream operation with two lists of inputs, with number of + * batches to run same as the number of expected output values + */ def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( input1: Seq[Seq[U]], input2: Seq[Seq[V]], @@ -231,6 +265,16 @@ trait TestSuiteBase extends FunSuite with Logging { testOperation[U, V, W](input1, input2, operation, expectedOutput, -1, useSet) } + /** + * Test binary DStream operation with two lists of inputs + * @param input1 First sequence of input collections + * @param input2 Second sequence of input collections + * @param operation Binary DStream operation to be applied to the 2 inputs + * @param expectedOutput Sequence of expected output collections + * @param numBatches Number of batches to run the operation for + * @param useSet Compare the output values with the expected output values + * as sets (order matters) or as lists (order does not matter) + */ def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( input1: Seq[Seq[U]], input2: Seq[Seq[V]], diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index 3e20e16708..4bc5229465 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -209,7 +209,7 @@ class WindowOperationsSuite extends TestSuiteBase { val expectedOutput = bigGroupByOutput.map(_.map(x => (x._1, x._2.toSet))) val windowTime = Seconds(2) val slideTime = Seconds(1) - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.groupByKeyAndWindow(windowTime, slideTime) .map(x => (x._1, x._2.toSet)) @@ -223,7 +223,7 @@ class WindowOperationsSuite extends TestSuiteBase { val expectedOutput = Seq( Seq(1), Seq(2), Seq(3), Seq(3), Seq(1), Seq(0)) val windowTime = Seconds(2) val slideTime = Seconds(1) - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[Int]) => s.countByWindow(windowTime, slideTime) testOperation(input, operation, expectedOutput, numBatches, true) } @@ -233,7 +233,7 @@ class WindowOperationsSuite extends TestSuiteBase { val expectedOutput = Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 2)), Seq(("a", 1), ("b", 3))) val windowTime = Seconds(2) val slideTime = Seconds(1) - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.countByKeyAndWindow(windowTime, slideTime).map(x => (x._1, x._2.toInt)) } @@ -251,7 +251,7 @@ class WindowOperationsSuite extends TestSuiteBase { slideTime: Time = Seconds(1) ) { test("window - " + name) { - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[Int]) => s.window(windowTime, slideTime) testOperation(input, operation, expectedOutput, numBatches, true) } @@ -265,7 +265,7 @@ class WindowOperationsSuite extends TestSuiteBase { slideTime: Time = Seconds(1) ) { test("reduceByKeyAndWindow - " + name) { - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist() } @@ -281,7 +281,7 @@ class WindowOperationsSuite extends TestSuiteBase { slideTime: Time = Seconds(1) ) { test("reduceByKeyAndWindowInv - " + name) { - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime) .persist() -- cgit v1.2.3 From f803953998d6b931b266c69acab97b3ece628713 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 30 Dec 2012 12:43:06 -0800 Subject: Raise exception when hashing Java arrays (SPARK-597) --- core/src/main/scala/spark/PairRDDFunctions.scala | 27 +++++++++++++++++++++++ core/src/main/scala/spark/Partitioner.scala | 4 ++++ core/src/main/scala/spark/RDD.scala | 6 +++++ core/src/test/scala/spark/PartitioningSuite.scala | 21 ++++++++++++++++++ 4 files changed, 58 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index d3e206b353..413c944a66 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -52,6 +52,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( mergeCombiners: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true): RDD[(K, C)] = { + if (getKeyClass().isArray) { + if (mapSideCombine) { + throw new SparkException("Cannot use map-side combining with array keys.") + } + if (partitioner.isInstanceOf[HashPartitioner]) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + } val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) if (mapSideCombine) { @@ -92,6 +100,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * before sending results to a reducer, similarly to a "combiner" in MapReduce. */ def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = { + + if (getKeyClass().isArray) { + throw new SparkException("reduceByKeyLocally() does not support array keys") + } + def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = { val map = new JHashMap[K, V] for ((k, v) <- iter) { @@ -165,6 +178,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * be set to true. */ def partitionBy(partitioner: Partitioner, mapSideCombine: Boolean = false): RDD[(K, V)] = { + if (getKeyClass().isArray) { + if (mapSideCombine) { + throw new SparkException("Cannot use map-side combining with array keys.") + } + if (partitioner.isInstanceOf[HashPartitioner]) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + } if (mapSideCombine) { def createCombiner(v: V) = ArrayBuffer(v) def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v @@ -336,6 +357,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * list of values for that key in `this` as well as `other`. */ def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } val cg = new CoGroupedRDD[K]( Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]), partitioner) @@ -352,6 +376,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner) : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } val cg = new CoGroupedRDD[K]( Seq(self.asInstanceOf[RDD[(_, _)]], other1.asInstanceOf[RDD[(_, _)]], diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala index b71021a082..9d5b966e1e 100644 --- a/core/src/main/scala/spark/Partitioner.scala +++ b/core/src/main/scala/spark/Partitioner.scala @@ -11,6 +11,10 @@ abstract class Partitioner extends Serializable { /** * A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`. + * + * Java arrays have hashCodes that are based on the arrays' identities rather than their contents, + * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will + * produce an unexpected or incorrect result. */ class HashPartitioner(partitions: Int) extends Partitioner { def numPartitions = partitions diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index d15c6f7396..7e38583391 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -417,6 +417,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): Map[T, Long] = { + if (elementClassManifest.erasure.isArray) { + throw new SparkException("countByValue() does not support arrays") + } // TODO: This should perhaps be distributed by default. def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = { val map = new OLMap[T] @@ -445,6 +448,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial timeout: Long, confidence: Double = 0.95 ): PartialResult[Map[T, BoundedDouble]] = { + if (elementClassManifest.erasure.isArray) { + throw new SparkException("countByValueApprox() does not support arrays") + } val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) => val map = new OLMap[T] while (iter.hasNext) { diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index 3dadc7acec..f09b602a7b 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -107,4 +107,25 @@ class PartitioningSuite extends FunSuite with BeforeAndAfter { assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner) } + + test("partitioning Java arrays should fail") { + sc = new SparkContext("local", "test") + val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x)) + val arrPairs: RDD[(Array[Int], Int)] = + sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x)) + + assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array")) + // We can't catch all usages of arrays, since they might occur inside other collections: + //assert(fails { arrPairs.distinct() }) + assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) + } } -- cgit v1.2.3 From feadaf72f44e7c66521c03171592671d4c441bda Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 31 Dec 2012 14:05:11 -0800 Subject: Mark key as not loading in CacheTracker even when compute() fails --- core/src/main/scala/spark/CacheTracker.scala | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 3d79078733..c8c4063cad 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -202,26 +202,26 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b loading.add(key) } } - // If we got here, we have to load the split - // Tell the master that we're doing so - //val host = System.getProperty("spark.hostname", Utils.localHostName) - //val future = trackerActor !! AddedToCache(rdd.id, split.index, host) - // TODO: fetch any remote copy of the split that may be available - // TODO: also register a listener for when it unloads - logInfo("Computing partition " + split) - val elements = new ArrayBuffer[Any] - elements ++= rdd.compute(split, context) try { + // If we got here, we have to load the split + // Tell the master that we're doing so + //val host = System.getProperty("spark.hostname", Utils.localHostName) + //val future = trackerActor !! AddedToCache(rdd.id, split.index, host) + // TODO: fetch any remote copy of the split that may be available + // TODO: also register a listener for when it unloads + val elements = new ArrayBuffer[Any] + logInfo("Computing partition " + split) + elements ++= rdd.compute(split, context) // Try to put this block in the blockManager blockManager.put(key, elements, storageLevel, true) //future.apply() // Wait for the reply from the cache tracker + return elements.iterator.asInstanceOf[Iterator[T]] } finally { loading.synchronized { loading.remove(key) loading.notifyAll() } } - return elements.iterator.asInstanceOf[Iterator[T]] } } -- cgit v1.2.3 From 21636ee4faf30126b36ad568753788327e634857 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 1 Jan 2013 07:52:31 -0800 Subject: Test with exception while computing cached RDD. --- core/src/test/scala/spark/RDDSuite.scala | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) (limited to 'core') diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 08da9a1c4d..45e6c5f840 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -88,6 +88,29 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(rdd.collect().toList === List(1, 2, 3, 4)) } + test("caching with failures") { + sc = new SparkContext("local", "test") + val onlySplit = new Split { override def index: Int = 0 } + var shouldFail = true + val rdd = new RDD[Int](sc) { + override def splits: Array[Split] = Array(onlySplit) + override val dependencies = List[Dependency[_]]() + override def compute(split: Split, context: TaskContext): Iterator[Int] = { + if (shouldFail) { + throw new Exception("injected failure") + } else { + return Array(1, 2, 3, 4).iterator + } + } + }.cache() + val thrown = intercept[Exception]{ + rdd.collect() + } + assert(thrown.getMessage.contains("injected failure")) + shouldFail = false + assert(rdd.collect().toList === List(1, 2, 3, 4)) + } + test("coalesced RDDs") { sc = new SparkContext("local", "test") val data = sc.parallelize(1 to 10, 10) -- cgit v1.2.3 From 58072a7340e20251ed810457bc67a79f106bae42 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 1 Jan 2013 07:59:16 -0800 Subject: Remove some dead comments --- core/src/main/scala/spark/CacheTracker.scala | 6 ------ 1 file changed, 6 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index c8c4063cad..04c26b2e40 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -204,17 +204,11 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b } try { // If we got here, we have to load the split - // Tell the master that we're doing so - //val host = System.getProperty("spark.hostname", Utils.localHostName) - //val future = trackerActor !! AddedToCache(rdd.id, split.index, host) - // TODO: fetch any remote copy of the split that may be available - // TODO: also register a listener for when it unloads val elements = new ArrayBuffer[Any] logInfo("Computing partition " + split) elements ++= rdd.compute(split, context) // Try to put this block in the blockManager blockManager.put(key, elements, storageLevel, true) - //future.apply() // Wait for the reply from the cache tracker return elements.iterator.asInstanceOf[Iterator[T]] } finally { loading.synchronized { -- cgit v1.2.3 From 170e451fbdd308ae77065bd9c0f2bd278abf0cb7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Jan 2013 13:52:14 -0800 Subject: Minor documentation and style fixes for PySpark. --- .../scala/spark/api/python/PythonPartitioner.scala | 4 +- .../main/scala/spark/api/python/PythonRDD.scala | 43 +++++++++++----- docs/index.md | 8 ++- docs/python-programming-guide.md | 3 +- pyspark/examples/kmeans.py | 13 +++-- pyspark/examples/logistic_regression.py | 57 ++++++++++++++++++++++ pyspark/examples/lr.py | 57 ---------------------- pyspark/examples/pi.py | 5 +- pyspark/examples/tc.py | 49 ------------------- pyspark/examples/transitive_closure.py | 50 +++++++++++++++++++ pyspark/examples/wordcount.py | 4 +- pyspark/pyspark/__init__.py | 13 ++++- 12 files changed, 172 insertions(+), 134 deletions(-) create mode 100755 pyspark/examples/logistic_regression.py delete mode 100755 pyspark/examples/lr.py delete mode 100644 pyspark/examples/tc.py create mode 100644 pyspark/examples/transitive_closure.py (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala index 2c829508e5..648d9402b0 100644 --- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -17,9 +17,9 @@ private[spark] class PythonPartitioner(override val numPartitions: Int) extends val hashCode = { if (key.isInstanceOf[Array[Byte]]) { Arrays.hashCode(key.asInstanceOf[Array[Byte]]) - } - else + } else { key.hashCode() + } } val mod = hashCode % numPartitions if (mod < 0) { diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index dc48378fdc..19a039e330 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -13,8 +13,12 @@ import spark.rdd.PipedRDD private[spark] class PythonRDD[T: ClassManifest]( - parent: RDD[T], command: Seq[String], envVars: java.util.Map[String, String], - preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) + parent: RDD[T], + command: Seq[String], + envVars: java.util.Map[String, String], + preservePartitoning: Boolean, + pythonExec: String, + broadcastVars: java.util.List[Broadcast[Array[Byte]]]) extends RDD[Array[Byte]](parent.context) { // Similar to Runtime.exec(), if we are given a single string, split it into words @@ -38,8 +42,8 @@ private[spark] class PythonRDD[T: ClassManifest]( // Add the environmental variables to the process. val currentEnvVars = pb.environment() - envVars.foreach { - case (variable, value) => currentEnvVars.put(variable, value) + for ((variable, value) <- envVars) { + currentEnvVars.put(variable, value) } val proc = pb.start() @@ -116,6 +120,10 @@ private[spark] class PythonRDD[T: ClassManifest]( val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } +/** + * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. + * This is used by PySpark's shuffle operations. + */ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Array[Byte], Array[Byte])](prev.context) { override def splits = prev.splits @@ -139,6 +147,16 @@ private[spark] object PythonRDD { * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. * The data format is a 32-bit integer representing the pickled object's length (in bytes), * followed by the pickled data. + * + * Pickle module: + * + * http://docs.python.org/2/library/pickle.html + * + * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules: + * + * http://hg.python.org/cpython/file/2.6/Lib/pickle.py + * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py + * * @param elem the object to write * @param dOut a data output stream */ @@ -201,15 +219,14 @@ private[spark] object PythonRDD { } private object Pickle { - def b(x: Int): Byte = x.asInstanceOf[Byte] - val PROTO: Byte = b(0x80) - val TWO: Byte = b(0x02) - val BINUNICODE : Byte = 'X' - val STOP : Byte = '.' - val TUPLE2 : Byte = b(0x86) - val EMPTY_LIST : Byte = ']' - val MARK : Byte = '(' - val APPENDS : Byte = 'e' + val PROTO: Byte = 0x80.toByte + val TWO: Byte = 0x02.toByte + val BINUNICODE: Byte = 'X' + val STOP: Byte = '.' + val TUPLE2: Byte = 0x86.toByte + val EMPTY_LIST: Byte = ']' + val MARK: Byte = '(' + val APPENDS: Byte = 'e' } private class ExtractValue extends spark.api.java.function.Function[(Array[Byte], diff --git a/docs/index.md b/docs/index.md index 33ab58a962..848b585333 100644 --- a/docs/index.md +++ b/docs/index.md @@ -8,7 +8,7 @@ TODO(andyk): Rewrite to make the Java API a first class part of the story. {% endcomment %} Spark is a MapReduce-like cluster computing framework designed for low-latency iterative jobs and interactive use from an interpreter. -It provides clean, language-integrated APIs in Scala, Java, and Python, with a rich array of parallel operators. +It provides clean, language-integrated APIs in [Scala](scala-programming-guide.html), [Java](java-programming-guide.html), and [Python](python-programming-guide.html), with a rich array of parallel operators. Spark can run on top of the [Apache Mesos](http://incubator.apache.org/mesos/) cluster manager, [Hadoop YARN](http://hadoop.apache.org/docs/r2.0.1-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html), Amazon EC2, or without an independent resource manager ("standalone mode"). @@ -61,6 +61,11 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`). * [Java Programming Guide](java-programming-guide.html): using Spark from Java * [Python Programming Guide](python-programming-guide.html): using Spark from Python +**API Docs:** + +* [Java/Scala (Scaladoc)](api/core/index.html) +* [Python (Epydoc)](api/pyspark/index.html) + **Deployment guides:** * [Running Spark on Amazon EC2](ec2-scripts.html): scripts that let you launch a cluster on EC2 in about 5 minutes @@ -73,7 +78,6 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`). * [Configuration](configuration.html): customize Spark via its configuration system * [Tuning Guide](tuning.html): best practices to optimize performance and memory use -* API Docs: [Java/Scala (Scaladoc)](api/core/index.html) and [Python (Epydoc)](api/pyspark/index.html) * [Bagel](bagel-programming-guide.html): an implementation of Google's Pregel on Spark * [Contributing to Spark](contributing-to-spark.html) diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index b7c747f905..d88d4eb42d 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -17,8 +17,7 @@ There are a few key differences between the Python and Scala APIs: * Python is dynamically typed, so RDDs can hold objects of different types. * PySpark does not currently support the following Spark features: - Accumulators - - Special functions on RRDs of doubles, such as `mean` and `stdev` - - Approximate jobs / functions, such as `countApprox` and `sumApprox`. + - Special functions on RDDs of doubles, such as `mean` and `stdev` - `lookup` - `mapPartitionsWithSplit` - `persist` at storage levels other than `MEMORY_ONLY` diff --git a/pyspark/examples/kmeans.py b/pyspark/examples/kmeans.py index 9cc366f03c..ad2be21178 100644 --- a/pyspark/examples/kmeans.py +++ b/pyspark/examples/kmeans.py @@ -1,18 +1,21 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" import sys -from pyspark.context import SparkContext -from numpy import array, sum as np_sum +import numpy as np +from pyspark import SparkContext def parseVector(line): - return array([float(x) for x in line.split(' ')]) + return np.array([float(x) for x in line.split(' ')]) def closestPoint(p, centers): bestIndex = 0 closest = float("+inf") for i in range(len(centers)): - tempDist = np_sum((p - centers[i]) ** 2) + tempDist = np.sum((p - centers[i]) ** 2) if tempDist < closest: closest = tempDist bestIndex = i @@ -41,7 +44,7 @@ if __name__ == "__main__": newPoints = pointStats.map( lambda (x, (y, z)): (x, y / z)).collect() - tempDist = sum(np_sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) + tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) for (x, y) in newPoints: kPoints[x] = y diff --git a/pyspark/examples/logistic_regression.py b/pyspark/examples/logistic_regression.py new file mode 100755 index 0000000000..f13698a86f --- /dev/null +++ b/pyspark/examples/logistic_regression.py @@ -0,0 +1,57 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" +from collections import namedtuple +from math import exp +from os.path import realpath +import sys + +import numpy as np +from pyspark import SparkContext + + +N = 100000 # Number of data points +D = 10 # Number of dimensions +R = 0.7 # Scaling factor +ITERATIONS = 5 +np.random.seed(42) + + +DataPoint = namedtuple("DataPoint", ['x', 'y']) +from lr import DataPoint # So that DataPoint is properly serialized + + +def generateData(): + def generatePoint(i): + y = -1 if i % 2 == 0 else 1 + x = np.random.normal(size=D) + (y * R) + return DataPoint(x, y) + return [generatePoint(i) for i in range(N)] + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonLR []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)]) + slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 + points = sc.parallelize(generateData(), slices).cache() + + # Initialize w to a random value + w = 2 * np.random.ranf(size=D) - 1 + print "Initial w: " + str(w) + + def add(x, y): + x += y + return x + + for i in range(1, ITERATIONS + 1): + print "On iteration %i" % i + + gradient = points.map(lambda p: + (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x + ).reduce(add) + w -= gradient + + print "Final w: " + str(w) diff --git a/pyspark/examples/lr.py b/pyspark/examples/lr.py deleted file mode 100755 index 5fca0266b8..0000000000 --- a/pyspark/examples/lr.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -This example requires numpy (http://www.numpy.org/) -""" -from collections import namedtuple -from math import exp -from os.path import realpath -import sys - -import numpy as np -from pyspark.context import SparkContext - - -N = 100000 # Number of data points -D = 10 # Number of dimensions -R = 0.7 # Scaling factor -ITERATIONS = 5 -np.random.seed(42) - - -DataPoint = namedtuple("DataPoint", ['x', 'y']) -from lr import DataPoint # So that DataPoint is properly serialized - - -def generateData(): - def generatePoint(i): - y = -1 if i % 2 == 0 else 1 - x = np.random.normal(size=D) + (y * R) - return DataPoint(x, y) - return [generatePoint(i) for i in range(N)] - - -if __name__ == "__main__": - if len(sys.argv) == 1: - print >> sys.stderr, \ - "Usage: PythonLR []" - exit(-1) - sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)]) - slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 - points = sc.parallelize(generateData(), slices).cache() - - # Initialize w to a random value - w = 2 * np.random.ranf(size=D) - 1 - print "Initial w: " + str(w) - - def add(x, y): - x += y - return x - - for i in range(1, ITERATIONS + 1): - print "On iteration %i" % i - - gradient = points.map(lambda p: - (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x - ).reduce(add) - w -= gradient - - print "Final w: " + str(w) diff --git a/pyspark/examples/pi.py b/pyspark/examples/pi.py index 348bbc5dce..127cba029b 100644 --- a/pyspark/examples/pi.py +++ b/pyspark/examples/pi.py @@ -1,13 +1,14 @@ import sys from random import random from operator import add -from pyspark.context import SparkContext + +from pyspark import SparkContext if __name__ == "__main__": if len(sys.argv) == 1: print >> sys.stderr, \ - "Usage: PythonPi []" + "Usage: PythonPi []" exit(-1) sc = SparkContext(sys.argv[1], "PythonPi") slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 diff --git a/pyspark/examples/tc.py b/pyspark/examples/tc.py deleted file mode 100644 index 9630e72b47..0000000000 --- a/pyspark/examples/tc.py +++ /dev/null @@ -1,49 +0,0 @@ -import sys -from random import Random -from pyspark.context import SparkContext - -numEdges = 200 -numVertices = 100 -rand = Random(42) - - -def generateGraph(): - edges = set() - while len(edges) < numEdges: - src = rand.randrange(0, numEdges) - dst = rand.randrange(0, numEdges) - if src != dst: - edges.add((src, dst)) - return edges - - -if __name__ == "__main__": - if len(sys.argv) == 1: - print >> sys.stderr, \ - "Usage: PythonTC []" - exit(-1) - sc = SparkContext(sys.argv[1], "PythonTC") - slices = sys.argv[2] if len(sys.argv) > 2 else 2 - tc = sc.parallelize(generateGraph(), slices).cache() - - # Linear transitive closure: each round grows paths by one edge, - # by joining the graph's edges with the already-discovered paths. - # e.g. join the path (y, z) from the TC with the edge (x, y) from - # the graph to obtain the path (x, z). - - # Because join() joins on keys, the edges are stored in reversed order. - edges = tc.map(lambda (x, y): (y, x)) - - oldCount = 0L - nextCount = tc.count() - while True: - oldCount = nextCount - # Perform the join, obtaining an RDD of (y, (z, x)) pairs, - # then project the result to obtain the new (x, z) paths. - new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) - tc = tc.union(new_edges).distinct().cache() - nextCount = tc.count() - if nextCount == oldCount: - break - - print "TC has %i edges" % tc.count() diff --git a/pyspark/examples/transitive_closure.py b/pyspark/examples/transitive_closure.py new file mode 100644 index 0000000000..73f7f8fbaf --- /dev/null +++ b/pyspark/examples/transitive_closure.py @@ -0,0 +1,50 @@ +import sys +from random import Random + +from pyspark import SparkContext + +numEdges = 200 +numVertices = 100 +rand = Random(42) + + +def generateGraph(): + edges = set() + while len(edges) < numEdges: + src = rand.randrange(0, numEdges) + dst = rand.randrange(0, numEdges) + if src != dst: + edges.add((src, dst)) + return edges + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonTC []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonTC") + slices = sys.argv[2] if len(sys.argv) > 2 else 2 + tc = sc.parallelize(generateGraph(), slices).cache() + + # Linear transitive closure: each round grows paths by one edge, + # by joining the graph's edges with the already-discovered paths. + # e.g. join the path (y, z) from the TC with the edge (x, y) from + # the graph to obtain the path (x, z). + + # Because join() joins on keys, the edges are stored in reversed order. + edges = tc.map(lambda (x, y): (y, x)) + + oldCount = 0L + nextCount = tc.count() + while True: + oldCount = nextCount + # Perform the join, obtaining an RDD of (y, (z, x)) pairs, + # then project the result to obtain the new (x, z) paths. + new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) + tc = tc.union(new_edges).distinct().cache() + nextCount = tc.count() + if nextCount == oldCount: + break + + print "TC has %i edges" % tc.count() diff --git a/pyspark/examples/wordcount.py b/pyspark/examples/wordcount.py index 8365c070e8..857160624b 100644 --- a/pyspark/examples/wordcount.py +++ b/pyspark/examples/wordcount.py @@ -1,6 +1,8 @@ import sys from operator import add -from pyspark.context import SparkContext + +from pyspark import SparkContext + if __name__ == "__main__": if len(sys.argv) < 3: diff --git a/pyspark/pyspark/__init__.py b/pyspark/pyspark/__init__.py index 8f8402b62b..1ab360a666 100644 --- a/pyspark/pyspark/__init__.py +++ b/pyspark/pyspark/__init__.py @@ -1,9 +1,20 @@ +""" +PySpark is a Python API for Spark. + +Public classes: + + - L{SparkContext} + Main entry point for Spark functionality. + - L{RDD} + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. +""" import sys import os sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "pyspark/lib/py4j0.7.egg")) from pyspark.context import SparkContext +from pyspark.rdd import RDD -__all__ = ["SparkContext"] +__all__ = ["SparkContext", "RDD"] -- cgit v1.2.3 From b58340dbd9a741331fc4c3829b08c093560056c2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Jan 2013 14:48:45 -0800 Subject: Rename top-level 'pyspark' directory to 'python' --- .../main/scala/spark/api/python/PythonRDD.scala | 2 +- docs/_plugins/copy_api_dirs.rb | 8 +- pyspark-shell | 3 + pyspark/.gitignore | 2 - pyspark/epydoc.conf | 19 - pyspark/examples/kmeans.py | 52 -- pyspark/examples/logistic_regression.py | 57 -- pyspark/examples/pi.py | 21 - pyspark/examples/transitive_closure.py | 50 -- pyspark/examples/wordcount.py | 19 - pyspark/lib/PY4J_LICENSE.txt | 27 - pyspark/lib/PY4J_VERSION.txt | 1 - pyspark/lib/py4j0.7.egg | Bin 191756 -> 0 bytes pyspark/lib/py4j0.7.jar | Bin 103286 -> 0 bytes pyspark/pyspark-shell | 3 - pyspark/pyspark/__init__.py | 20 - pyspark/pyspark/broadcast.py | 48 - pyspark/pyspark/cloudpickle.py | 974 --------------------- pyspark/pyspark/context.py | 158 ---- pyspark/pyspark/java_gateway.py | 38 - pyspark/pyspark/join.py | 92 -- pyspark/pyspark/rdd.py | 713 --------------- pyspark/pyspark/serializers.py | 78 -- pyspark/pyspark/shell.py | 33 - pyspark/pyspark/worker.py | 40 - pyspark/run-pyspark | 28 - python/.gitignore | 2 + python/epydoc.conf | 19 + python/examples/kmeans.py | 52 ++ python/examples/logistic_regression.py | 57 ++ python/examples/pi.py | 21 + python/examples/transitive_closure.py | 50 ++ python/examples/wordcount.py | 19 + python/lib/PY4J_LICENSE.txt | 27 + python/lib/PY4J_VERSION.txt | 1 + python/lib/py4j0.7.egg | Bin 0 -> 191756 bytes python/lib/py4j0.7.jar | Bin 0 -> 103286 bytes python/pyspark/__init__.py | 20 + python/pyspark/broadcast.py | 48 + python/pyspark/cloudpickle.py | 974 +++++++++++++++++++++ python/pyspark/context.py | 158 ++++ python/pyspark/java_gateway.py | 38 + python/pyspark/join.py | 92 ++ python/pyspark/rdd.py | 713 +++++++++++++++ python/pyspark/serializers.py | 78 ++ python/pyspark/shell.py | 33 + python/pyspark/worker.py | 40 + run | 2 +- run-pyspark | 28 + run2.cmd | 2 +- 50 files changed, 2480 insertions(+), 2480 deletions(-) create mode 100755 pyspark-shell delete mode 100644 pyspark/.gitignore delete mode 100644 pyspark/epydoc.conf delete mode 100644 pyspark/examples/kmeans.py delete mode 100755 pyspark/examples/logistic_regression.py delete mode 100644 pyspark/examples/pi.py delete mode 100644 pyspark/examples/transitive_closure.py delete mode 100644 pyspark/examples/wordcount.py delete mode 100644 pyspark/lib/PY4J_LICENSE.txt delete mode 100644 pyspark/lib/PY4J_VERSION.txt delete mode 100644 pyspark/lib/py4j0.7.egg delete mode 100644 pyspark/lib/py4j0.7.jar delete mode 100755 pyspark/pyspark-shell delete mode 100644 pyspark/pyspark/__init__.py delete mode 100644 pyspark/pyspark/broadcast.py delete mode 100644 pyspark/pyspark/cloudpickle.py delete mode 100644 pyspark/pyspark/context.py delete mode 100644 pyspark/pyspark/java_gateway.py delete mode 100644 pyspark/pyspark/join.py delete mode 100644 pyspark/pyspark/rdd.py delete mode 100644 pyspark/pyspark/serializers.py delete mode 100644 pyspark/pyspark/shell.py delete mode 100644 pyspark/pyspark/worker.py delete mode 100755 pyspark/run-pyspark create mode 100644 python/.gitignore create mode 100644 python/epydoc.conf create mode 100644 python/examples/kmeans.py create mode 100755 python/examples/logistic_regression.py create mode 100644 python/examples/pi.py create mode 100644 python/examples/transitive_closure.py create mode 100644 python/examples/wordcount.py create mode 100644 python/lib/PY4J_LICENSE.txt create mode 100644 python/lib/PY4J_VERSION.txt create mode 100644 python/lib/py4j0.7.egg create mode 100644 python/lib/py4j0.7.jar create mode 100644 python/pyspark/__init__.py create mode 100644 python/pyspark/broadcast.py create mode 100644 python/pyspark/cloudpickle.py create mode 100644 python/pyspark/context.py create mode 100644 python/pyspark/java_gateway.py create mode 100644 python/pyspark/join.py create mode 100644 python/pyspark/rdd.py create mode 100644 python/pyspark/serializers.py create mode 100644 python/pyspark/shell.py create mode 100644 python/pyspark/worker.py create mode 100755 run-pyspark (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 19a039e330..cf60d14f03 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -38,7 +38,7 @@ private[spark] class PythonRDD[T: ClassManifest]( override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = { val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") - val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) + val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py")) // Add the environmental variables to the process. val currentEnvVars = pb.environment() diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 577f3ebe70..c9ce589c1b 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -30,8 +30,8 @@ if ENV['SKIP_SCALADOC'] != '1' end if ENV['SKIP_EPYDOC'] != '1' - puts "Moving to pyspark directory and building epydoc." - cd("../pyspark") + puts "Moving to python directory and building epydoc." + cd("../python") puts `epydoc --config epydoc.conf` puts "Moving back into docs dir." @@ -40,8 +40,8 @@ if ENV['SKIP_EPYDOC'] != '1' puts "echo making directory pyspark" mkdir_p "pyspark" - puts "cp -r ../pyspark/docs/. api/pyspark" - cp_r("../pyspark/docs/.", "api/pyspark") + puts "cp -r ../python/docs/. api/pyspark" + cp_r("../python/docs/.", "api/pyspark") cd("..") end diff --git a/pyspark-shell b/pyspark-shell new file mode 100755 index 0000000000..27aaac3a26 --- /dev/null +++ b/pyspark-shell @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +FWDIR="`dirname $0`" +exec $FWDIR/run-pyspark $FWDIR/python/pyspark/shell.py "$@" diff --git a/pyspark/.gitignore b/pyspark/.gitignore deleted file mode 100644 index 5c56e638f9..0000000000 --- a/pyspark/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -*.pyc -docs/ diff --git a/pyspark/epydoc.conf b/pyspark/epydoc.conf deleted file mode 100644 index 91ac984ba2..0000000000 --- a/pyspark/epydoc.conf +++ /dev/null @@ -1,19 +0,0 @@ -[epydoc] # Epydoc section marker (required by ConfigParser) - -# Information about the project. -name: PySpark -url: http://spark-project.org - -# The list of modules to document. Modules can be named using -# dotted names, module filenames, or package directory names. -# This option may be repeated. -modules: pyspark - -# Write html output to the directory "apidocs" -output: html -target: docs/ - -private: no - -exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers - pyspark.java_gateway pyspark.examples pyspark.shell diff --git a/pyspark/examples/kmeans.py b/pyspark/examples/kmeans.py deleted file mode 100644 index ad2be21178..0000000000 --- a/pyspark/examples/kmeans.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -This example requires numpy (http://www.numpy.org/) -""" -import sys - -import numpy as np -from pyspark import SparkContext - - -def parseVector(line): - return np.array([float(x) for x in line.split(' ')]) - - -def closestPoint(p, centers): - bestIndex = 0 - closest = float("+inf") - for i in range(len(centers)): - tempDist = np.sum((p - centers[i]) ** 2) - if tempDist < closest: - closest = tempDist - bestIndex = i - return bestIndex - - -if __name__ == "__main__": - if len(sys.argv) < 5: - print >> sys.stderr, \ - "Usage: PythonKMeans " - exit(-1) - sc = SparkContext(sys.argv[1], "PythonKMeans") - lines = sc.textFile(sys.argv[2]) - data = lines.map(parseVector).cache() - K = int(sys.argv[3]) - convergeDist = float(sys.argv[4]) - - kPoints = data.takeSample(False, K, 34) - tempDist = 1.0 - - while tempDist > convergeDist: - closest = data.map( - lambda p : (closestPoint(p, kPoints), (p, 1))) - pointStats = closest.reduceByKey( - lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2)) - newPoints = pointStats.map( - lambda (x, (y, z)): (x, y / z)).collect() - - tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) - - for (x, y) in newPoints: - kPoints[x] = y - - print "Final centers: " + str(kPoints) diff --git a/pyspark/examples/logistic_regression.py b/pyspark/examples/logistic_regression.py deleted file mode 100755 index f13698a86f..0000000000 --- a/pyspark/examples/logistic_regression.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -This example requires numpy (http://www.numpy.org/) -""" -from collections import namedtuple -from math import exp -from os.path import realpath -import sys - -import numpy as np -from pyspark import SparkContext - - -N = 100000 # Number of data points -D = 10 # Number of dimensions -R = 0.7 # Scaling factor -ITERATIONS = 5 -np.random.seed(42) - - -DataPoint = namedtuple("DataPoint", ['x', 'y']) -from lr import DataPoint # So that DataPoint is properly serialized - - -def generateData(): - def generatePoint(i): - y = -1 if i % 2 == 0 else 1 - x = np.random.normal(size=D) + (y * R) - return DataPoint(x, y) - return [generatePoint(i) for i in range(N)] - - -if __name__ == "__main__": - if len(sys.argv) == 1: - print >> sys.stderr, \ - "Usage: PythonLR []" - exit(-1) - sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)]) - slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 - points = sc.parallelize(generateData(), slices).cache() - - # Initialize w to a random value - w = 2 * np.random.ranf(size=D) - 1 - print "Initial w: " + str(w) - - def add(x, y): - x += y - return x - - for i in range(1, ITERATIONS + 1): - print "On iteration %i" % i - - gradient = points.map(lambda p: - (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x - ).reduce(add) - w -= gradient - - print "Final w: " + str(w) diff --git a/pyspark/examples/pi.py b/pyspark/examples/pi.py deleted file mode 100644 index 127cba029b..0000000000 --- a/pyspark/examples/pi.py +++ /dev/null @@ -1,21 +0,0 @@ -import sys -from random import random -from operator import add - -from pyspark import SparkContext - - -if __name__ == "__main__": - if len(sys.argv) == 1: - print >> sys.stderr, \ - "Usage: PythonPi []" - exit(-1) - sc = SparkContext(sys.argv[1], "PythonPi") - slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 - n = 100000 * slices - def f(_): - x = random() * 2 - 1 - y = random() * 2 - 1 - return 1 if x ** 2 + y ** 2 < 1 else 0 - count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) - print "Pi is roughly %f" % (4.0 * count / n) diff --git a/pyspark/examples/transitive_closure.py b/pyspark/examples/transitive_closure.py deleted file mode 100644 index 73f7f8fbaf..0000000000 --- a/pyspark/examples/transitive_closure.py +++ /dev/null @@ -1,50 +0,0 @@ -import sys -from random import Random - -from pyspark import SparkContext - -numEdges = 200 -numVertices = 100 -rand = Random(42) - - -def generateGraph(): - edges = set() - while len(edges) < numEdges: - src = rand.randrange(0, numEdges) - dst = rand.randrange(0, numEdges) - if src != dst: - edges.add((src, dst)) - return edges - - -if __name__ == "__main__": - if len(sys.argv) == 1: - print >> sys.stderr, \ - "Usage: PythonTC []" - exit(-1) - sc = SparkContext(sys.argv[1], "PythonTC") - slices = sys.argv[2] if len(sys.argv) > 2 else 2 - tc = sc.parallelize(generateGraph(), slices).cache() - - # Linear transitive closure: each round grows paths by one edge, - # by joining the graph's edges with the already-discovered paths. - # e.g. join the path (y, z) from the TC with the edge (x, y) from - # the graph to obtain the path (x, z). - - # Because join() joins on keys, the edges are stored in reversed order. - edges = tc.map(lambda (x, y): (y, x)) - - oldCount = 0L - nextCount = tc.count() - while True: - oldCount = nextCount - # Perform the join, obtaining an RDD of (y, (z, x)) pairs, - # then project the result to obtain the new (x, z) paths. - new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) - tc = tc.union(new_edges).distinct().cache() - nextCount = tc.count() - if nextCount == oldCount: - break - - print "TC has %i edges" % tc.count() diff --git a/pyspark/examples/wordcount.py b/pyspark/examples/wordcount.py deleted file mode 100644 index 857160624b..0000000000 --- a/pyspark/examples/wordcount.py +++ /dev/null @@ -1,19 +0,0 @@ -import sys -from operator import add - -from pyspark import SparkContext - - -if __name__ == "__main__": - if len(sys.argv) < 3: - print >> sys.stderr, \ - "Usage: PythonWordCount " - exit(-1) - sc = SparkContext(sys.argv[1], "PythonWordCount") - lines = sc.textFile(sys.argv[2], 1) - counts = lines.flatMap(lambda x: x.split(' ')) \ - .map(lambda x: (x, 1)) \ - .reduceByKey(add) - output = counts.collect() - for (word, count) in output: - print "%s : %i" % (word, count) diff --git a/pyspark/lib/PY4J_LICENSE.txt b/pyspark/lib/PY4J_LICENSE.txt deleted file mode 100644 index a70279ca14..0000000000 --- a/pyspark/lib/PY4J_LICENSE.txt +++ /dev/null @@ -1,27 +0,0 @@ - -Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -- Redistributions of source code must retain the above copyright notice, this -list of conditions and the following disclaimer. - -- Redistributions in binary form must reproduce the above copyright notice, -this list of conditions and the following disclaimer in the documentation -and/or other materials provided with the distribution. - -- The name of the author may not be used to endorse or promote products -derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. diff --git a/pyspark/lib/PY4J_VERSION.txt b/pyspark/lib/PY4J_VERSION.txt deleted file mode 100644 index 04a0cd52a8..0000000000 --- a/pyspark/lib/PY4J_VERSION.txt +++ /dev/null @@ -1 +0,0 @@ -b7924aabe9c5e63f0a4d8bbd17019534c7ec014e diff --git a/pyspark/lib/py4j0.7.egg b/pyspark/lib/py4j0.7.egg deleted file mode 100644 index f8a339d8ee..0000000000 Binary files a/pyspark/lib/py4j0.7.egg and /dev/null differ diff --git a/pyspark/lib/py4j0.7.jar b/pyspark/lib/py4j0.7.jar deleted file mode 100644 index 73b7ddb7d1..0000000000 Binary files a/pyspark/lib/py4j0.7.jar and /dev/null differ diff --git a/pyspark/pyspark-shell b/pyspark/pyspark-shell deleted file mode 100755 index e3736826e8..0000000000 --- a/pyspark/pyspark-shell +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env bash -FWDIR="`dirname $0`" -exec $FWDIR/run-pyspark $FWDIR/pyspark/shell.py "$@" diff --git a/pyspark/pyspark/__init__.py b/pyspark/pyspark/__init__.py deleted file mode 100644 index 1ab360a666..0000000000 --- a/pyspark/pyspark/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -PySpark is a Python API for Spark. - -Public classes: - - - L{SparkContext} - Main entry point for Spark functionality. - - L{RDD} - A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. -""" -import sys -import os -sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "pyspark/lib/py4j0.7.egg")) - - -from pyspark.context import SparkContext -from pyspark.rdd import RDD - - -__all__ = ["SparkContext", "RDD"] diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py deleted file mode 100644 index 93876fa738..0000000000 --- a/pyspark/pyspark/broadcast.py +++ /dev/null @@ -1,48 +0,0 @@ -""" ->>> from pyspark.context import SparkContext ->>> sc = SparkContext('local', 'test') ->>> b = sc.broadcast([1, 2, 3, 4, 5]) ->>> b.value -[1, 2, 3, 4, 5] - ->>> from pyspark.broadcast import _broadcastRegistry ->>> _broadcastRegistry[b.bid] = b ->>> from cPickle import dumps, loads ->>> loads(dumps(b)).value -[1, 2, 3, 4, 5] - ->>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() -[1, 2, 3, 4, 5, 1, 2, 3, 4, 5] - ->>> large_broadcast = sc.broadcast(list(range(10000))) -""" -# Holds broadcasted data received from Java, keyed by its id. -_broadcastRegistry = {} - - -def _from_id(bid): - from pyspark.broadcast import _broadcastRegistry - if bid not in _broadcastRegistry: - raise Exception("Broadcast variable '%s' not loaded!" % bid) - return _broadcastRegistry[bid] - - -class Broadcast(object): - def __init__(self, bid, value, java_broadcast=None, pickle_registry=None): - self.value = value - self.bid = bid - self._jbroadcast = java_broadcast - self._pickle_registry = pickle_registry - - def __reduce__(self): - self._pickle_registry.add(self) - return (_from_id, (self.bid, )) - - -def _test(): - import doctest - doctest.testmod() - - -if __name__ == "__main__": - _test() diff --git a/pyspark/pyspark/cloudpickle.py b/pyspark/pyspark/cloudpickle.py deleted file mode 100644 index 6a7c23a069..0000000000 --- a/pyspark/pyspark/cloudpickle.py +++ /dev/null @@ -1,974 +0,0 @@ -""" -This class is defined to override standard pickle functionality - -The goals of it follow: --Serialize lambdas and nested functions to compiled byte code --Deal with main module correctly --Deal with other non-serializable objects - -It does not include an unpickler, as standard python unpickling suffices. - -This module was extracted from the `cloud` package, developed by `PiCloud, Inc. -`_. - -Copyright (c) 2012, Regents of the University of California. -Copyright (c) 2009 `PiCloud, Inc. `_. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions -are met: - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - * Neither the name of the University of California, Berkeley nor the - names of its contributors may be used to endorse or promote - products derived from this software without specific prior written - permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED -TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -""" - - -import operator -import os -import pickle -import struct -import sys -import types -from functools import partial -import itertools -from copy_reg import _extension_registry, _inverted_registry, _extension_cache -import new -import dis -import traceback - -#relevant opcodes -STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) -DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) -LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) -GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] - -HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) -EXTENDED_ARG = chr(dis.EXTENDED_ARG) - -import logging -cloudLog = logging.getLogger("Cloud.Transport") - -try: - import ctypes -except (MemoryError, ImportError): - logging.warning('Exception raised on importing ctypes. Likely python bug.. some functionality will be disabled', exc_info = True) - ctypes = None - PyObject_HEAD = None -else: - - # for reading internal structures - PyObject_HEAD = [ - ('ob_refcnt', ctypes.c_size_t), - ('ob_type', ctypes.c_void_p), - ] - - -try: - from cStringIO import StringIO -except ImportError: - from StringIO import StringIO - -# These helper functions were copied from PiCloud's util module. -def islambda(func): - return getattr(func,'func_name') == '' - -def xrange_params(xrangeobj): - """Returns a 3 element tuple describing the xrange start, step, and len - respectively - - Note: Only guarentees that elements of xrange are the same. parameters may - be different. - e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same - though w/ iteration - """ - - xrange_len = len(xrangeobj) - if not xrange_len: #empty - return (0,1,0) - start = xrangeobj[0] - if xrange_len == 1: #one element - return start, 1, 1 - return (start, xrangeobj[1] - xrangeobj[0], xrange_len) - -#debug variables intended for developer use: -printSerialization = False -printMemoization = False - -useForcedImports = True #Should I use forced imports for tracking? - - - -class CloudPickler(pickle.Pickler): - - dispatch = pickle.Pickler.dispatch.copy() - savedForceImports = False - savedDjangoEnv = False #hack tro transport django environment - - def __init__(self, file, protocol=None, min_size_to_save= 0): - pickle.Pickler.__init__(self,file,protocol) - self.modules = set() #set of modules needed to depickle - self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env - - def dump(self, obj): - # note: not thread safe - # minimal side-effects, so not fixing - recurse_limit = 3000 - base_recurse = sys.getrecursionlimit() - if base_recurse < recurse_limit: - sys.setrecursionlimit(recurse_limit) - self.inject_addons() - try: - return pickle.Pickler.dump(self, obj) - except RuntimeError, e: - if 'recursion' in e.args[0]: - msg = """Could not pickle object as excessively deep recursion required. - Try _fast_serialization=2 or contact PiCloud support""" - raise pickle.PicklingError(msg) - finally: - new_recurse = sys.getrecursionlimit() - if new_recurse == recurse_limit: - sys.setrecursionlimit(base_recurse) - - def save_buffer(self, obj): - """Fallback to save_string""" - pickle.Pickler.save_string(self,str(obj)) - dispatch[buffer] = save_buffer - - #block broken objects - def save_unsupported(self, obj, pack=None): - raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) - dispatch[types.GeneratorType] = save_unsupported - - #python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it - try: - slice(0,1).__reduce__() - except TypeError: #can't pickle - - dispatch[slice] = save_unsupported - - #itertools objects do not pickle! - for v in itertools.__dict__.values(): - if type(v) is type: - dispatch[v] = save_unsupported - - - def save_dict(self, obj): - """hack fix - If the dict is a global, deal with it in a special way - """ - #print 'saving', obj - if obj is __builtins__: - self.save_reduce(_get_module_builtins, (), obj=obj) - else: - pickle.Pickler.save_dict(self, obj) - dispatch[pickle.DictionaryType] = save_dict - - - def save_module(self, obj, pack=struct.pack): - """ - Save a module as an import - """ - #print 'try save import', obj.__name__ - self.modules.add(obj) - self.save_reduce(subimport,(obj.__name__,), obj=obj) - dispatch[types.ModuleType] = save_module #new type - - def save_codeobject(self, obj, pack=struct.pack): - """ - Save a code object - """ - #print 'try to save codeobj: ', obj - args = ( - obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, - obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, - obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars - ) - self.save_reduce(types.CodeType, args, obj=obj) - dispatch[types.CodeType] = save_codeobject #new type - - def save_function(self, obj, name=None, pack=struct.pack): - """ Registered with the dispatch to handle all function types. - - Determines what kind of function obj is (e.g. lambda, defined at - interactive prompt, etc) and handles the pickling appropriately. - """ - write = self.write - - name = obj.__name__ - modname = pickle.whichmodule(obj, name) - #print 'which gives %s %s %s' % (modname, obj, name) - try: - themodule = sys.modules[modname] - except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__ - modname = '__main__' - - if modname == '__main__': - themodule = None - - if themodule: - self.modules.add(themodule) - - if not self.savedDjangoEnv: - #hack for django - if we detect the settings module, we transport it - django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '') - if django_settings: - django_mod = sys.modules.get(django_settings) - if django_mod: - cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name) - self.savedDjangoEnv = True - self.modules.add(django_mod) - write(pickle.MARK) - self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod) - write(pickle.POP_MARK) - - - # if func is lambda, def'ed at prompt, is in main, or is nested, then - # we'll pickle the actual function object rather than simply saving a - # reference (as is done in default pickler), via save_function_tuple. - if islambda(obj) or obj.func_code.co_filename == '' or themodule == None: - #Force server to import modules that have been imported in main - modList = None - if themodule == None and not self.savedForceImports: - mainmod = sys.modules['__main__'] - if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'): - modList = list(mainmod.___pyc_forcedImports__) - self.savedForceImports = True - self.save_function_tuple(obj, modList) - return - else: # func is nested - klass = getattr(themodule, name, None) - if klass is None or klass is not obj: - self.save_function_tuple(obj, [themodule]) - return - - if obj.__dict__: - # essentially save_reduce, but workaround needed to avoid recursion - self.save(_restore_attr) - write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n') - self.memoize(obj) - self.save(obj.__dict__) - write(pickle.TUPLE + pickle.REDUCE) - else: - write(pickle.GLOBAL + modname + '\n' + name + '\n') - self.memoize(obj) - dispatch[types.FunctionType] = save_function - - def save_function_tuple(self, func, forced_imports): - """ Pickles an actual func object. - - A func comprises: code, globals, defaults, closure, and dict. We - extract and save these, injecting reducing functions at certain points - to recreate the func object. Keep in mind that some of these pieces - can contain a ref to the func itself. Thus, a naive save on these - pieces could trigger an infinite loop of save's. To get around that, - we first create a skeleton func object using just the code (this is - safe, since this won't contain a ref to the func), and memoize it as - soon as it's created. The other stuff can then be filled in later. - """ - save = self.save - write = self.write - - # save the modules (if any) - if forced_imports: - write(pickle.MARK) - save(_modules_to_main) - #print 'forced imports are', forced_imports - - forced_names = map(lambda m: m.__name__, forced_imports) - save((forced_names,)) - - #save((forced_imports,)) - write(pickle.REDUCE) - write(pickle.POP_MARK) - - code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) - - save(_fill_function) # skeleton function updater - write(pickle.MARK) # beginning of tuple that _fill_function expects - - # create a skeleton function object and memoize it - save(_make_skel_func) - save((code, len(closure), base_globals)) - write(pickle.REDUCE) - self.memoize(func) - - # save the rest of the func data needed by _fill_function - save(f_globals) - save(defaults) - save(closure) - save(dct) - write(pickle.TUPLE) - write(pickle.REDUCE) # applies _fill_function on the tuple - - @staticmethod - def extract_code_globals(co): - """ - Find all globals names read or written to by codeblock co - """ - code = co.co_code - names = co.co_names - out_names = set() - - n = len(code) - i = 0 - extended_arg = 0 - while i < n: - op = code[i] - - i = i+1 - if op >= HAVE_ARGUMENT: - oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg - extended_arg = 0 - i = i+2 - if op == EXTENDED_ARG: - extended_arg = oparg*65536L - if op in GLOBAL_OPS: - out_names.add(names[oparg]) - #print 'extracted', out_names, ' from ', names - return out_names - - def extract_func_data(self, func): - """ - Turn the function into a tuple of data necessary to recreate it: - code, globals, defaults, closure, dict - """ - code = func.func_code - - # extract all global ref's - func_global_refs = CloudPickler.extract_code_globals(code) - if code.co_consts: # see if nested function have any global refs - for const in code.co_consts: - if type(const) is types.CodeType and const.co_names: - func_global_refs = func_global_refs.union( CloudPickler.extract_code_globals(const)) - # process all variables referenced by global environment - f_globals = {} - for var in func_global_refs: - #Some names, such as class functions are not global - we don't need them - if func.func_globals.has_key(var): - f_globals[var] = func.func_globals[var] - - # defaults requires no processing - defaults = func.func_defaults - - def get_contents(cell): - try: - return cell.cell_contents - except ValueError, e: #cell is empty error on not yet assigned - raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope') - - - # process closure - if func.func_closure: - closure = map(get_contents, func.func_closure) - else: - closure = [] - - # save the dict - dct = func.func_dict - - if printSerialization: - outvars = ['code: ' + str(code) ] - outvars.append('globals: ' + str(f_globals)) - outvars.append('defaults: ' + str(defaults)) - outvars.append('closure: ' + str(closure)) - print 'function ', func, 'is extracted to: ', ', '.join(outvars) - - base_globals = self.globals_ref.get(id(func.func_globals), {}) - self.globals_ref[id(func.func_globals)] = base_globals - - return (code, f_globals, defaults, closure, dct, base_globals) - - def save_global(self, obj, name=None, pack=struct.pack): - write = self.write - memo = self.memo - - if name is None: - name = obj.__name__ - - modname = getattr(obj, "__module__", None) - if modname is None: - modname = pickle.whichmodule(obj, name) - - try: - __import__(modname) - themodule = sys.modules[modname] - except (ImportError, KeyError, AttributeError): #should never occur - raise pickle.PicklingError( - "Can't pickle %r: Module %s cannot be found" % - (obj, modname)) - - if modname == '__main__': - themodule = None - - if themodule: - self.modules.add(themodule) - - sendRef = True - typ = type(obj) - #print 'saving', obj, typ - try: - try: #Deal with case when getattribute fails with exceptions - klass = getattr(themodule, name) - except (AttributeError): - if modname == '__builtin__': #new.* are misrepeported - modname = 'new' - __import__(modname) - themodule = sys.modules[modname] - try: - klass = getattr(themodule, name) - except AttributeError, a: - #print themodule, name, obj, type(obj) - raise pickle.PicklingError("Can't pickle builtin %s" % obj) - else: - raise - - except (ImportError, KeyError, AttributeError): - if typ == types.TypeType or typ == types.ClassType: - sendRef = False - else: #we can't deal with this - raise - else: - if klass is not obj and (typ == types.TypeType or typ == types.ClassType): - sendRef = False - if not sendRef: - #note: Third party types might crash this - add better checks! - d = dict(obj.__dict__) #copy dict proxy to a dict - if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties - d.pop('__dict__',None) - d.pop('__weakref__',None) - - # hack as __new__ is stored differently in the __dict__ - new_override = d.get('__new__', None) - if new_override: - d['__new__'] = obj.__new__ - - self.save_reduce(type(obj),(obj.__name__,obj.__bases__, - d),obj=obj) - #print 'internal reduce dask %s %s' % (obj, d) - return - - if self.proto >= 2: - code = _extension_registry.get((modname, name)) - if code: - assert code > 0 - if code <= 0xff: - write(pickle.EXT1 + chr(code)) - elif code <= 0xffff: - write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8)) - else: - write(pickle.EXT4 + pack("= 2 and getattr(func, "__name__", "") == "__newobj__": - #Added fix to allow transient - cls = args[0] - if not hasattr(cls, "__new__"): - raise pickle.PicklingError( - "args[0] from __newobj__ args has no __new__") - if obj is not None and cls is not obj.__class__: - raise pickle.PicklingError( - "args[0] from __newobj__ args has the wrong class") - args = args[1:] - save(cls) - - #Don't pickle transient entries - if hasattr(obj, '__transient__'): - transient = obj.__transient__ - state = state.copy() - - for k in list(state.keys()): - if k in transient: - del state[k] - - save(args) - write(pickle.NEWOBJ) - else: - save(func) - save(args) - write(pickle.REDUCE) - - if obj is not None: - self.memoize(obj) - - # More new special cases (that work with older protocols as - # well): when __reduce__ returns a tuple with 4 or 5 items, - # the 4th and 5th item should be iterators that provide list - # items and dict items (as (key, value) tuples), or None. - - if listitems is not None: - self._batch_appends(listitems) - - if dictitems is not None: - self._batch_setitems(dictitems) - - if state is not None: - #print 'obj %s has state %s' % (obj, state) - save(state) - write(pickle.BUILD) - - - def save_xrange(self, obj): - """Save an xrange object in python 2.5 - Python 2.6 supports this natively - """ - range_params = xrange_params(obj) - self.save_reduce(_build_xrange,range_params) - - #python2.6+ supports xrange pickling. some py2.5 extensions might as well. We just test it - try: - xrange(0).__reduce__() - except TypeError: #can't pickle -- use PiCloud pickler - dispatch[xrange] = save_xrange - - def save_partial(self, obj): - """Partial objects do not serialize correctly in python2.x -- this fixes the bugs""" - self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords)) - - if sys.version_info < (2,7): #2.7 supports partial pickling - dispatch[partial] = save_partial - - - def save_file(self, obj): - """Save a file""" - import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute - from ..transport.adapter import SerializingAdapter - - if not hasattr(obj, 'name') or not hasattr(obj, 'mode'): - raise pickle.PicklingError("Cannot pickle files that do not map to an actual file") - if obj.name == '': - return self.save_reduce(getattr, (sys,'stdout'), obj=obj) - if obj.name == '': - return self.save_reduce(getattr, (sys,'stderr'), obj=obj) - if obj.name == '': - raise pickle.PicklingError("Cannot pickle standard input") - if hasattr(obj, 'isatty') and obj.isatty(): - raise pickle.PicklingError("Cannot pickle files that map to tty objects") - if 'r' not in obj.mode: - raise pickle.PicklingError("Cannot pickle files that are not opened for reading") - name = obj.name - try: - fsize = os.stat(name).st_size - except OSError: - raise pickle.PicklingError("Cannot pickle file %s as it cannot be stat" % name) - - if obj.closed: - #create an empty closed string io - retval = pystringIO.StringIO("") - retval.close() - elif not fsize: #empty file - retval = pystringIO.StringIO("") - try: - tmpfile = file(name) - tst = tmpfile.read(1) - except IOError: - raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) - tmpfile.close() - if tst != '': - raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name) - elif fsize > SerializingAdapter.max_transmit_data: - raise pickle.PicklingError("Cannot pickle file %s as it exceeds cloudconf.py's max_transmit_data of %d" % - (name,SerializingAdapter.max_transmit_data)) - else: - try: - tmpfile = file(name) - contents = tmpfile.read(SerializingAdapter.max_transmit_data) - tmpfile.close() - except IOError: - raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) - retval = pystringIO.StringIO(contents) - curloc = obj.tell() - retval.seek(curloc) - - retval.name = name - self.save(retval) #save stringIO - self.memoize(obj) - - dispatch[file] = save_file - """Special functions for Add-on libraries""" - - def inject_numpy(self): - numpy = sys.modules.get('numpy') - if not numpy or not hasattr(numpy, 'ufunc'): - return - self.dispatch[numpy.ufunc] = self.__class__.save_ufunc - - numpy_tst_mods = ['numpy', 'scipy.special'] - def save_ufunc(self, obj): - """Hack function for saving numpy ufunc objects""" - name = obj.__name__ - for tst_mod_name in self.numpy_tst_mods: - tst_mod = sys.modules.get(tst_mod_name, None) - if tst_mod: - if name in tst_mod.__dict__: - self.save_reduce(_getobject, (tst_mod_name, name)) - return - raise pickle.PicklingError('cannot save %s. Cannot resolve what module it is defined in' % str(obj)) - - def inject_timeseries(self): - """Handle bugs with pickling scikits timeseries""" - tseries = sys.modules.get('scikits.timeseries.tseries') - if not tseries or not hasattr(tseries, 'Timeseries'): - return - self.dispatch[tseries.Timeseries] = self.__class__.save_timeseries - - def save_timeseries(self, obj): - import scikits.timeseries.tseries as ts - - func, reduce_args, state = obj.__reduce__() - if func != ts._tsreconstruct: - raise pickle.PicklingError('timeseries using unexpected reconstruction function %s' % str(func)) - state = (1, - obj.shape, - obj.dtype, - obj.flags.fnc, - obj._data.tostring(), - ts.getmaskarray(obj).tostring(), - obj._fill_value, - obj._dates.shape, - obj._dates.__array__().tostring(), - obj._dates.dtype, #added -- preserve type - obj.freq, - obj._optinfo, - ) - return self.save_reduce(_genTimeSeries, (reduce_args, state)) - - def inject_email(self): - """Block email LazyImporters from being saved""" - email = sys.modules.get('email') - if not email: - return - self.dispatch[email.LazyImporter] = self.__class__.save_unsupported - - def inject_addons(self): - """Plug in system. Register additional pickling functions if modules already loaded""" - self.inject_numpy() - self.inject_timeseries() - self.inject_email() - - """Python Imaging Library""" - def save_image(self, obj): - if not obj.im and obj.fp and 'r' in obj.fp.mode and obj.fp.name \ - and not obj.fp.closed and (not hasattr(obj, 'isatty') or not obj.isatty()): - #if image not loaded yet -- lazy load - self.save_reduce(_lazyloadImage,(obj.fp,), obj=obj) - else: - #image is loaded - just transmit it over - self.save_reduce(_generateImage, (obj.size, obj.mode, obj.tostring()), obj=obj) - - """ - def memoize(self, obj): - pickle.Pickler.memoize(self, obj) - if printMemoization: - print 'memoizing ' + str(obj) - """ - - - -# Shorthands for legacy support - -def dump(obj, file, protocol=2): - CloudPickler(file, protocol).dump(obj) - -def dumps(obj, protocol=2): - file = StringIO() - - cp = CloudPickler(file,protocol) - cp.dump(obj) - - #print 'cloud dumped', str(obj), str(cp.modules) - - return file.getvalue() - - -#hack for __import__ not working as desired -def subimport(name): - __import__(name) - return sys.modules[name] - -#hack to load django settings: -def django_settings_load(name): - modified_env = False - - if 'DJANGO_SETTINGS_MODULE' not in os.environ: - os.environ['DJANGO_SETTINGS_MODULE'] = name # must set name first due to circular deps - modified_env = True - try: - module = subimport(name) - except Exception, i: - print >> sys.stderr, 'Cloud not import django settings %s:' % (name) - print_exec(sys.stderr) - if modified_env: - del os.environ['DJANGO_SETTINGS_MODULE'] - else: - #add project directory to sys,path: - if hasattr(module,'__file__'): - dirname = os.path.split(module.__file__)[0] + '/' - sys.path.append(dirname) - -# restores function attributes -def _restore_attr(obj, attr): - for key, val in attr.items(): - setattr(obj, key, val) - return obj - -def _get_module_builtins(): - return pickle.__builtins__ - -def print_exec(stream): - ei = sys.exc_info() - traceback.print_exception(ei[0], ei[1], ei[2], None, stream) - -def _modules_to_main(modList): - """Force every module in modList to be placed into main""" - if not modList: - return - - main = sys.modules['__main__'] - for modname in modList: - if type(modname) is str: - try: - mod = __import__(modname) - except Exception, i: #catch all... - sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \ -A version mismatch is likely. Specific error was:\n' % modname) - print_exec(sys.stderr) - else: - setattr(main,mod.__name__, mod) - else: - #REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD) - #In old version actual module was sent - setattr(main,modname.__name__, modname) - -#object generators: -def _build_xrange(start, step, len): - """Built xrange explicitly""" - return xrange(start, start + step*len, step) - -def _genpartial(func, args, kwds): - if not args: - args = () - if not kwds: - kwds = {} - return partial(func, *args, **kwds) - - -def _fill_function(func, globals, defaults, closure, dict): - """ Fills in the rest of function data into the skeleton function object - that were created via _make_skel_func(). - """ - func.func_globals.update(globals) - func.func_defaults = defaults - func.func_dict = dict - - if len(closure) != len(func.func_closure): - raise pickle.UnpicklingError("closure lengths don't match up") - for i in range(len(closure)): - _change_cell_value(func.func_closure[i], closure[i]) - - return func - -def _make_skel_func(code, num_closures, base_globals = None): - """ Creates a skeleton function object that contains just the provided - code and the correct number of cells in func_closure. All other - func attributes (e.g. func_globals) are empty. - """ - #build closure (cells): - if not ctypes: - raise Exception('ctypes failed to import; cannot build function') - - cellnew = ctypes.pythonapi.PyCell_New - cellnew.restype = ctypes.py_object - cellnew.argtypes = (ctypes.py_object,) - dummy_closure = tuple(map(lambda i: cellnew(None), range(num_closures))) - - if base_globals is None: - base_globals = {} - base_globals['__builtins__'] = __builtins__ - - return types.FunctionType(code, base_globals, - None, None, dummy_closure) - -# this piece of opaque code is needed below to modify 'cell' contents -cell_changer_code = new.code( - 1, 1, 2, 0, - ''.join([ - chr(dis.opmap['LOAD_FAST']), '\x00\x00', - chr(dis.opmap['DUP_TOP']), - chr(dis.opmap['STORE_DEREF']), '\x00\x00', - chr(dis.opmap['RETURN_VALUE']) - ]), - (), (), ('newval',), '', 'cell_changer', 1, '', ('c',), () -) - -def _change_cell_value(cell, newval): - """ Changes the contents of 'cell' object to newval """ - return new.function(cell_changer_code, {}, None, (), (cell,))(newval) - -"""Constructors for 3rd party libraries -Note: These can never be renamed due to client compatibility issues""" - -def _getobject(modname, attribute): - mod = __import__(modname) - return mod.__dict__[attribute] - -def _generateImage(size, mode, str_rep): - """Generate image from string representation""" - import Image - i = Image.new(mode, size) - i.fromstring(str_rep) - return i - -def _lazyloadImage(fp): - import Image - fp.seek(0) #works in almost any case - return Image.open(fp) - -"""Timeseries""" -def _genTimeSeries(reduce_args, state): - import scikits.timeseries.tseries as ts - from numpy import ndarray - from numpy.ma import MaskedArray - - - time_series = ts._tsreconstruct(*reduce_args) - - #from setstate modified - (ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state - #print 'regenerating %s' % dtyp - - MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv)) - _dates = time_series._dates - #_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ - ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm)) - _dates.freq = frq - _dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None, - toobj=None, toord=None, tostr=None)) - # Update the _optinfo dictionary - time_series._optinfo.update(infodict) - return time_series - diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py deleted file mode 100644 index 6172d69dcf..0000000000 --- a/pyspark/pyspark/context.py +++ /dev/null @@ -1,158 +0,0 @@ -import os -import atexit -from tempfile import NamedTemporaryFile - -from pyspark.broadcast import Broadcast -from pyspark.java_gateway import launch_gateway -from pyspark.serializers import dump_pickle, write_with_length, batched -from pyspark.rdd import RDD - -from py4j.java_collections import ListConverter - - -class SparkContext(object): - """ - Main entry point for Spark functionality. A SparkContext represents the - connection to a Spark cluster, and can be used to create L{RDD}s and - broadcast variables on that cluster. - """ - - gateway = launch_gateway() - jvm = gateway.jvm - _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile - _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile - - def __init__(self, master, jobName, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024): - """ - Create a new SparkContext. - - @param master: Cluster URL to connect to - (e.g. mesos://host:port, spark://host:port, local[4]). - @param jobName: A name for your job, to display on the cluster web UI - @param sparkHome: Location where Spark is installed on cluster nodes. - @param pyFiles: Collection of .zip or .py files to send to the cluster - and add to PYTHONPATH. These can be paths on the local file - system or HDFS, HTTP, HTTPS, or FTP URLs. - @param environment: A dictionary of environment variables to set on - worker nodes. - @param batchSize: The number of Python objects represented as a single - Java object. Set 1 to disable batching or -1 to use an - unlimited batch size. - """ - self.master = master - self.jobName = jobName - self.sparkHome = sparkHome or None # None becomes null in Py4J - self.environment = environment or {} - self.batchSize = batchSize # -1 represents a unlimited batch size - - # Create the Java SparkContext through Py4J - empty_string_array = self.gateway.new_array(self.jvm.String, 0) - self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, - empty_string_array) - - self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python') - # Broadcast's __reduce__ method stores Broadcast instances here. - # This allows other code to determine which Broadcast instances have - # been pickled, so it can determine which Java broadcast objects to - # send. - self._pickled_broadcast_vars = set() - - # Deploy any code dependencies specified in the constructor - for path in (pyFiles or []): - self.addPyFile(path) - - @property - def defaultParallelism(self): - """ - Default level of parallelism to use when not given by user (e.g. for - reduce tasks) - """ - return self._jsc.sc().defaultParallelism() - - def __del__(self): - if self._jsc: - self._jsc.stop() - - def stop(self): - """ - Shut down the SparkContext. - """ - self._jsc.stop() - self._jsc = None - - def parallelize(self, c, numSlices=None): - """ - Distribute a local Python collection to form an RDD. - """ - numSlices = numSlices or self.defaultParallelism - # Calling the Java parallelize() method with an ArrayList is too slow, - # because it sends O(n) Py4J commands. As an alternative, serialized - # objects are written to a file and loaded through textFile(). - tempFile = NamedTemporaryFile(delete=False) - atexit.register(lambda: os.unlink(tempFile.name)) - if self.batchSize != 1: - c = batched(c, self.batchSize) - for x in c: - write_with_length(dump_pickle(x), tempFile) - tempFile.close() - jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) - return RDD(jrdd, self) - - def textFile(self, name, minSplits=None): - """ - Read a text file from HDFS, a local file system (available on all - nodes), or any Hadoop-supported file system URI, and return it as an - RDD of Strings. - """ - minSplits = minSplits or min(self.defaultParallelism, 2) - jrdd = self._jsc.textFile(name, minSplits) - return RDD(jrdd, self) - - def union(self, rdds): - """ - Build the union of a list of RDDs. - """ - first = rdds[0]._jrdd - rest = [x._jrdd for x in rdds[1:]] - rest = ListConverter().convert(rest, self.gateway._gateway_client) - return RDD(self._jsc.union(first, rest), self) - - def broadcast(self, value): - """ - Broadcast a read-only variable to the cluster, returning a C{Broadcast} - object for reading it in distributed functions. The variable will be - sent to each cluster only once. - """ - jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) - return Broadcast(jbroadcast.id(), value, jbroadcast, - self._pickled_broadcast_vars) - - def addFile(self, path): - """ - Add a file to be downloaded into the working directory of this Spark - job on every node. The C{path} passed can be either a local file, - a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, - HTTPS or FTP URI. - """ - self._jsc.sc().addFile(path) - - def clearFiles(self): - """ - Clear the job's list of files added by L{addFile} or L{addPyFile} so - that they do not get downloaded to any new nodes. - """ - # TODO: remove added .py or .zip files from the PYTHONPATH? - self._jsc.sc().clearFiles() - - def addPyFile(self, path): - """ - Add a .py or .zip dependency for all tasks to be executed on this - SparkContext in the future. The C{path} passed can be either a local - file, a file in HDFS (or other Hadoop-supported filesystems), or an - HTTP, HTTPS or FTP URI. - """ - self.addFile(path) - filename = path.split("/")[-1] - os.environ["PYTHONPATH"] = \ - "%s:%s" % (filename, os.environ["PYTHONPATH"]) diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py deleted file mode 100644 index 2329e536cc..0000000000 --- a/pyspark/pyspark/java_gateway.py +++ /dev/null @@ -1,38 +0,0 @@ -import os -import sys -from subprocess import Popen, PIPE -from threading import Thread -from py4j.java_gateway import java_import, JavaGateway, GatewayClient - - -SPARK_HOME = os.environ["SPARK_HOME"] - - -def launch_gateway(): - # Launch the Py4j gateway using Spark's run command so that we pick up the - # proper classpath and SPARK_MEM settings from spark-env.sh - command = [os.path.join(SPARK_HOME, "run"), "py4j.GatewayServer", - "--die-on-broken-pipe", "0"] - proc = Popen(command, stdout=PIPE, stdin=PIPE) - # Determine which ephemeral port the server started on: - port = int(proc.stdout.readline()) - # Create a thread to echo output from the GatewayServer, which is required - # for Java log output to show up: - class EchoOutputThread(Thread): - def __init__(self, stream): - Thread.__init__(self) - self.daemon = True - self.stream = stream - - def run(self): - while True: - line = self.stream.readline() - sys.stderr.write(line) - EchoOutputThread(proc.stdout).start() - # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=port), auto_convert=False) - # Import the classes used by PySpark - java_import(gateway.jvm, "spark.api.java.*") - java_import(gateway.jvm, "spark.api.python.*") - java_import(gateway.jvm, "scala.Tuple2") - return gateway diff --git a/pyspark/pyspark/join.py b/pyspark/pyspark/join.py deleted file mode 100644 index 7036c47980..0000000000 --- a/pyspark/pyspark/join.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -Copyright (c) 2011, Douban Inc. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - - * Neither the name of the Douban Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -""" - - -def _do_python_join(rdd, other, numSplits, dispatch): - vs = rdd.map(lambda (k, v): (k, (1, v))) - ws = other.map(lambda (k, v): (k, (2, v))) - return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch) - - -def python_join(rdd, other, numSplits): - def dispatch(seq): - vbuf, wbuf = [], [] - for (n, v) in seq: - if n == 1: - vbuf.append(v) - elif n == 2: - wbuf.append(v) - return [(v, w) for v in vbuf for w in wbuf] - return _do_python_join(rdd, other, numSplits, dispatch) - - -def python_right_outer_join(rdd, other, numSplits): - def dispatch(seq): - vbuf, wbuf = [], [] - for (n, v) in seq: - if n == 1: - vbuf.append(v) - elif n == 2: - wbuf.append(v) - if not vbuf: - vbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] - return _do_python_join(rdd, other, numSplits, dispatch) - - -def python_left_outer_join(rdd, other, numSplits): - def dispatch(seq): - vbuf, wbuf = [], [] - for (n, v) in seq: - if n == 1: - vbuf.append(v) - elif n == 2: - wbuf.append(v) - if not wbuf: - wbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] - return _do_python_join(rdd, other, numSplits, dispatch) - - -def python_cogroup(rdd, other, numSplits): - vs = rdd.map(lambda (k, v): (k, (1, v))) - ws = other.map(lambda (k, v): (k, (2, v))) - def dispatch(seq): - vbuf, wbuf = [], [] - for (n, v) in seq: - if n == 1: - vbuf.append(v) - elif n == 2: - wbuf.append(v) - return (vbuf, wbuf) - return vs.union(ws).groupByKey(numSplits).mapValues(dispatch) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py deleted file mode 100644 index cbffb6cc1f..0000000000 --- a/pyspark/pyspark/rdd.py +++ /dev/null @@ -1,713 +0,0 @@ -import atexit -from base64 import standard_b64encode as b64enc -import copy -from collections import defaultdict -from itertools import chain, ifilter, imap, product -import operator -import os -import shlex -from subprocess import Popen, PIPE -from tempfile import NamedTemporaryFile -from threading import Thread - -from pyspark import cloudpickle -from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ - read_from_pickle_file -from pyspark.join import python_join, python_left_outer_join, \ - python_right_outer_join, python_cogroup - -from py4j.java_collections import ListConverter, MapConverter - - -__all__ = ["RDD"] - - -class RDD(object): - """ - A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. - Represents an immutable, partitioned collection of elements that can be - operated on in parallel. - """ - - def __init__(self, jrdd, ctx): - self._jrdd = jrdd - self.is_cached = False - self.ctx = ctx - - @property - def context(self): - """ - The L{SparkContext} that this RDD was created on. - """ - return self.ctx - - def cache(self): - """ - Persist this RDD with the default storage level (C{MEMORY_ONLY}). - """ - self.is_cached = True - self._jrdd.cache() - return self - - # TODO persist(self, storageLevel) - - def map(self, f, preservesPartitioning=False): - """ - Return a new RDD containing the distinct elements in this RDD. - """ - def func(iterator): return imap(f, iterator) - return PipelinedRDD(self, func, preservesPartitioning) - - def flatMap(self, f, preservesPartitioning=False): - """ - Return a new RDD by first applying a function to all elements of this - RDD, and then flattening the results. - - >>> rdd = sc.parallelize([2, 3, 4]) - >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect()) - [1, 1, 1, 2, 2, 3] - >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) - [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] - """ - def func(iterator): return chain.from_iterable(imap(f, iterator)) - return self.mapPartitions(func, preservesPartitioning) - - def mapPartitions(self, f, preservesPartitioning=False): - """ - Return a new RDD by applying a function to each partition of this RDD. - - >>> rdd = sc.parallelize([1, 2, 3, 4], 2) - >>> def f(iterator): yield sum(iterator) - >>> rdd.mapPartitions(f).collect() - [3, 7] - """ - return PipelinedRDD(self, f, preservesPartitioning) - - # TODO: mapPartitionsWithSplit - - def filter(self, f): - """ - Return a new RDD containing only the elements that satisfy a predicate. - - >>> rdd = sc.parallelize([1, 2, 3, 4, 5]) - >>> rdd.filter(lambda x: x % 2 == 0).collect() - [2, 4] - """ - def func(iterator): return ifilter(f, iterator) - return self.mapPartitions(func) - - def distinct(self): - """ - Return a new RDD containing the distinct elements in this RDD. - - >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) - [1, 2, 3] - """ - return self.map(lambda x: (x, "")) \ - .reduceByKey(lambda x, _: x) \ - .map(lambda (x, _): x) - - # TODO: sampling needs to be re-implemented due to Batch - #def sample(self, withReplacement, fraction, seed): - # jrdd = self._jrdd.sample(withReplacement, fraction, seed) - # return RDD(jrdd, self.ctx) - - #def takeSample(self, withReplacement, num, seed): - # vals = self._jrdd.takeSample(withReplacement, num, seed) - # return [load_pickle(bytes(x)) for x in vals] - - def union(self, other): - """ - Return the union of this RDD and another one. - - >>> rdd = sc.parallelize([1, 1, 2, 3]) - >>> rdd.union(rdd).collect() - [1, 1, 2, 3, 1, 1, 2, 3] - """ - return RDD(self._jrdd.union(other._jrdd), self.ctx) - - def __add__(self, other): - """ - Return the union of this RDD and another one. - - >>> rdd = sc.parallelize([1, 1, 2, 3]) - >>> (rdd + rdd).collect() - [1, 1, 2, 3, 1, 1, 2, 3] - """ - if not isinstance(other, RDD): - raise TypeError - return self.union(other) - - # TODO: sort - - def glom(self): - """ - Return an RDD created by coalescing all elements within each partition - into a list. - - >>> rdd = sc.parallelize([1, 2, 3, 4], 2) - >>> sorted(rdd.glom().collect()) - [[1, 2], [3, 4]] - """ - def func(iterator): yield list(iterator) - return self.mapPartitions(func) - - def cartesian(self, other): - """ - Return the Cartesian product of this RDD and another one, that is, the - RDD of all pairs of elements C{(a, b)} where C{a} is in C{self} and - C{b} is in C{other}. - - >>> rdd = sc.parallelize([1, 2]) - >>> sorted(rdd.cartesian(rdd).collect()) - [(1, 1), (1, 2), (2, 1), (2, 2)] - """ - # Due to batching, we can't use the Java cartesian method. - java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx) - def unpack_batches(pair): - (x, y) = pair - if type(x) == Batch or type(y) == Batch: - xs = x.items if type(x) == Batch else [x] - ys = y.items if type(y) == Batch else [y] - for pair in product(xs, ys): - yield pair - else: - yield pair - return java_cartesian.flatMap(unpack_batches) - - def groupBy(self, f, numSplits=None): - """ - Return an RDD of grouped items. - - >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8]) - >>> result = rdd.groupBy(lambda x: x % 2).collect() - >>> sorted([(x, sorted(y)) for (x, y) in result]) - [(0, [2, 8]), (1, [1, 1, 3, 5])] - """ - return self.map(lambda x: (f(x), x)).groupByKey(numSplits) - - def pipe(self, command, env={}): - """ - Return an RDD created by piping elements to a forked external process. - - >>> sc.parallelize([1, 2, 3]).pipe('cat').collect() - ['1', '2', '3'] - """ - def func(iterator): - pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) - def pipe_objs(out): - for obj in iterator: - out.write(str(obj).rstrip('\n') + '\n') - out.close() - Thread(target=pipe_objs, args=[pipe.stdin]).start() - return (x.rstrip('\n') for x in pipe.stdout) - return self.mapPartitions(func) - - def foreach(self, f): - """ - Applies a function to all elements of this RDD. - - >>> def f(x): print x - >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) - """ - self.map(f).collect() # Force evaluation - - def collect(self): - """ - Return a list that contains all of the elements in this RDD. - """ - picklesInJava = self._jrdd.collect().iterator() - return list(self._collect_iterator_through_file(picklesInJava)) - - def _collect_iterator_through_file(self, iterator): - # Transferring lots of data through Py4J can be slow because - # socket.readline() is inefficient. Instead, we'll dump the data to a - # file and read it back. - tempFile = NamedTemporaryFile(delete=False) - tempFile.close() - def clean_up_file(): - try: os.unlink(tempFile.name) - except: pass - atexit.register(clean_up_file) - self.ctx._writeIteratorToPickleFile(iterator, tempFile.name) - # Read the data into Python and deserialize it: - with open(tempFile.name, 'rb') as tempFile: - for item in read_from_pickle_file(tempFile): - yield item - os.unlink(tempFile.name) - - def reduce(self, f): - """ - Reduces the elements of this RDD using the specified associative binary - operator. - - >>> from operator import add - >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add) - 15 - >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add) - 10 - """ - def func(iterator): - acc = None - for obj in iterator: - if acc is None: - acc = obj - else: - acc = f(obj, acc) - if acc is not None: - yield acc - vals = self.mapPartitions(func).collect() - return reduce(f, vals) - - def fold(self, zeroValue, op): - """ - Aggregate the elements of each partition, and then the results for all - the partitions, using a given associative function and a neutral "zero - value." - - The function C{op(t1, t2)} is allowed to modify C{t1} and return it - as its result value to avoid object allocation; however, it should not - modify C{t2}. - - >>> from operator import add - >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) - 15 - """ - def func(iterator): - acc = zeroValue - for obj in iterator: - acc = op(obj, acc) - yield acc - vals = self.mapPartitions(func).collect() - return reduce(op, vals, zeroValue) - - # TODO: aggregate - - def sum(self): - """ - Add up the elements in this RDD. - - >>> sc.parallelize([1.0, 2.0, 3.0]).sum() - 6.0 - """ - return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) - - def count(self): - """ - Return the number of elements in this RDD. - - >>> sc.parallelize([2, 3, 4]).count() - 3 - """ - return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum() - - def countByValue(self): - """ - Return the count of each unique value in this RDD as a dictionary of - (value, count) pairs. - - >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items()) - [(1, 2), (2, 3)] - """ - def countPartition(iterator): - counts = defaultdict(int) - for obj in iterator: - counts[obj] += 1 - yield counts - def mergeMaps(m1, m2): - for (k, v) in m2.iteritems(): - m1[k] += v - return m1 - return self.mapPartitions(countPartition).reduce(mergeMaps) - - def take(self, num): - """ - Take the first num elements of the RDD. - - This currently scans the partitions *one by one*, so it will be slow if - a lot of partitions are required. In that case, use L{collect} to get - the whole RDD instead. - - >>> sc.parallelize([2, 3, 4, 5, 6]).take(2) - [2, 3] - >>> sc.parallelize([2, 3, 4, 5, 6]).take(10) - [2, 3, 4, 5, 6] - """ - items = [] - splits = self._jrdd.splits() - taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0) - while len(items) < num and splits: - split = splits.pop(0) - iterator = self._jrdd.iterator(split, taskContext) - items.extend(self._collect_iterator_through_file(iterator)) - return items[:num] - - def first(self): - """ - Return the first element in this RDD. - - >>> sc.parallelize([2, 3, 4]).first() - 2 - """ - return self.take(1)[0] - - def saveAsTextFile(self, path): - """ - Save this RDD as a text file, using string representations of elements. - - >>> tempFile = NamedTemporaryFile(delete=True) - >>> tempFile.close() - >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name) - >>> from fileinput import input - >>> from glob import glob - >>> ''.join(input(glob(tempFile.name + "/part-0000*"))) - '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n' - """ - def func(iterator): - return (str(x).encode("utf-8") for x in iterator) - keyed = PipelinedRDD(self, func) - keyed._bypass_serializer = True - keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path) - - # Pair functions - - def collectAsMap(self): - """ - Return the key-value pairs in this RDD to the master as a dictionary. - - >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap() - >>> m[1] - 2 - >>> m[3] - 4 - """ - return dict(self.collect()) - - def reduceByKey(self, func, numSplits=None): - """ - Merge the values for each key using an associative reduce function. - - This will also perform the merging locally on each mapper before - sending results to a reducer, similarly to a "combiner" in MapReduce. - - Output will be hash-partitioned with C{numSplits} splits, or the - default parallelism level if C{numSplits} is not specified. - - >>> from operator import add - >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> sorted(rdd.reduceByKey(add).collect()) - [('a', 2), ('b', 1)] - """ - return self.combineByKey(lambda x: x, func, func, numSplits) - - def reduceByKeyLocally(self, func): - """ - Merge the values for each key using an associative reduce function, but - return the results immediately to the master as a dictionary. - - This will also perform the merging locally on each mapper before - sending results to a reducer, similarly to a "combiner" in MapReduce. - - >>> from operator import add - >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> sorted(rdd.reduceByKeyLocally(add).items()) - [('a', 2), ('b', 1)] - """ - def reducePartition(iterator): - m = {} - for (k, v) in iterator: - m[k] = v if k not in m else func(m[k], v) - yield m - def mergeMaps(m1, m2): - for (k, v) in m2.iteritems(): - m1[k] = v if k not in m1 else func(m1[k], v) - return m1 - return self.mapPartitions(reducePartition).reduce(mergeMaps) - - def countByKey(self): - """ - Count the number of elements for each key, and return the result to the - master as a dictionary. - - >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> sorted(rdd.countByKey().items()) - [('a', 2), ('b', 1)] - """ - return self.map(lambda x: x[0]).countByValue() - - def join(self, other, numSplits=None): - """ - Return an RDD containing all pairs of elements with matching keys in - C{self} and C{other}. - - Each pair of elements will be returned as a (k, (v1, v2)) tuple, where - (k, v1) is in C{self} and (k, v2) is in C{other}. - - Performs a hash join across the cluster. - - >>> x = sc.parallelize([("a", 1), ("b", 4)]) - >>> y = sc.parallelize([("a", 2), ("a", 3)]) - >>> sorted(x.join(y).collect()) - [('a', (1, 2)), ('a', (1, 3))] - """ - return python_join(self, other, numSplits) - - def leftOuterJoin(self, other, numSplits=None): - """ - Perform a left outer join of C{self} and C{other}. - - For each element (k, v) in C{self}, the resulting RDD will either - contain all pairs (k, (v, w)) for w in C{other}, or the pair - (k, (v, None)) if no elements in other have key k. - - Hash-partitions the resulting RDD into the given number of partitions. - - >>> x = sc.parallelize([("a", 1), ("b", 4)]) - >>> y = sc.parallelize([("a", 2)]) - >>> sorted(x.leftOuterJoin(y).collect()) - [('a', (1, 2)), ('b', (4, None))] - """ - return python_left_outer_join(self, other, numSplits) - - def rightOuterJoin(self, other, numSplits=None): - """ - Perform a right outer join of C{self} and C{other}. - - For each element (k, w) in C{other}, the resulting RDD will either - contain all pairs (k, (v, w)) for v in this, or the pair (k, (None, w)) - if no elements in C{self} have key k. - - Hash-partitions the resulting RDD into the given number of partitions. - - >>> x = sc.parallelize([("a", 1), ("b", 4)]) - >>> y = sc.parallelize([("a", 2)]) - >>> sorted(y.rightOuterJoin(x).collect()) - [('a', (2, 1)), ('b', (None, 4))] - """ - return python_right_outer_join(self, other, numSplits) - - # TODO: add option to control map-side combining - def partitionBy(self, numSplits, hashFunc=hash): - """ - Return a copy of the RDD partitioned using the specified partitioner. - - >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x)) - >>> sets = pairs.partitionBy(2).glom().collect() - >>> set(sets[0]).intersection(set(sets[1])) - set([]) - """ - if numSplits is None: - numSplits = self.ctx.defaultParallelism - # Transferring O(n) objects to Java is too expensive. Instead, we'll - # form the hash buckets in Python, transferring O(numSplits) objects - # to Java. Each object is a (splitNumber, [objects]) pair. - def add_shuffle_key(iterator): - buckets = defaultdict(list) - for (k, v) in iterator: - buckets[hashFunc(k) % numSplits].append((k, v)) - for (split, items) in buckets.iteritems(): - yield str(split) - yield dump_pickle(Batch(items)) - keyed = PipelinedRDD(self, add_shuffle_key) - keyed._bypass_serializer = True - pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) - jrdd = pairRDD.partitionBy(partitioner) - jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) - return RDD(jrdd, self.ctx) - - # TODO: add control over map-side aggregation - def combineByKey(self, createCombiner, mergeValue, mergeCombiners, - numSplits=None): - """ - Generic function to combine the elements for each key using a custom - set of aggregation functions. - - Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined - type" C. Note that V and C can be different -- for example, one might - group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]). - - Users provide three functions: - - - C{createCombiner}, which turns a V into a C (e.g., creates - a one-element list) - - C{mergeValue}, to merge a V into a C (e.g., adds it to the end of - a list) - - C{mergeCombiners}, to combine two C's into a single one. - - In addition, users can control the partitioning of the output RDD. - - >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> def f(x): return x - >>> def add(a, b): return a + str(b) - >>> sorted(x.combineByKey(str, add, add).collect()) - [('a', '11'), ('b', '1')] - """ - if numSplits is None: - numSplits = self.ctx.defaultParallelism - def combineLocally(iterator): - combiners = {} - for (k, v) in iterator: - if k not in combiners: - combiners[k] = createCombiner(v) - else: - combiners[k] = mergeValue(combiners[k], v) - return combiners.iteritems() - locally_combined = self.mapPartitions(combineLocally) - shuffled = locally_combined.partitionBy(numSplits) - def _mergeCombiners(iterator): - combiners = {} - for (k, v) in iterator: - if not k in combiners: - combiners[k] = v - else: - combiners[k] = mergeCombiners(combiners[k], v) - return combiners.iteritems() - return shuffled.mapPartitions(_mergeCombiners) - - # TODO: support variant with custom partitioner - def groupByKey(self, numSplits=None): - """ - Group the values for each key in the RDD into a single sequence. - Hash-partitions the resulting RDD with into numSplits partitions. - - >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> sorted(x.groupByKey().collect()) - [('a', [1, 1]), ('b', [1])] - """ - - def createCombiner(x): - return [x] - - def mergeValue(xs, x): - xs.append(x) - return xs - - def mergeCombiners(a, b): - return a + b - - return self.combineByKey(createCombiner, mergeValue, mergeCombiners, - numSplits) - - # TODO: add tests - def flatMapValues(self, f): - """ - Pass each value in the key-value pair RDD through a flatMap function - without changing the keys; this also retains the original RDD's - partitioning. - """ - flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) - return self.flatMap(flat_map_fn, preservesPartitioning=True) - - def mapValues(self, f): - """ - Pass each value in the key-value pair RDD through a map function - without changing the keys; this also retains the original RDD's - partitioning. - """ - map_values_fn = lambda (k, v): (k, f(v)) - return self.map(map_values_fn, preservesPartitioning=True) - - # TODO: support varargs cogroup of several RDDs. - def groupWith(self, other): - """ - Alias for cogroup. - """ - return self.cogroup(other) - - # TODO: add variant with custom parittioner - def cogroup(self, other, numSplits=None): - """ - For each key k in C{self} or C{other}, return a resulting RDD that - contains a tuple with the list of values for that key in C{self} as well - as C{other}. - - >>> x = sc.parallelize([("a", 1), ("b", 4)]) - >>> y = sc.parallelize([("a", 2)]) - >>> sorted(x.cogroup(y).collect()) - [('a', ([1], [2])), ('b', ([4], []))] - """ - return python_cogroup(self, other, numSplits) - - # TODO: `lookup` is disabled because we can't make direct comparisons based - # on the key; we need to compare the hash of the key to the hash of the - # keys in the pairs. This could be an expensive operation, since those - # hashes aren't retained. - - -class PipelinedRDD(RDD): - """ - Pipelined maps: - >>> rdd = sc.parallelize([1, 2, 3, 4]) - >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() - [4, 8, 12, 16] - >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect() - [4, 8, 12, 16] - - Pipelined reduces: - >>> from operator import add - >>> rdd.map(lambda x: 2 * x).reduce(add) - 20 - >>> rdd.flatMap(lambda x: [x, x]).reduce(add) - 20 - """ - def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and not prev.is_cached: - prev_func = prev.func - def pipeline_func(iterator): - return func(prev_func(iterator)) - self.func = pipeline_func - self.preservesPartitioning = \ - prev.preservesPartitioning and preservesPartitioning - self._prev_jrdd = prev._prev_jrdd - else: - self.func = func - self.preservesPartitioning = preservesPartitioning - self._prev_jrdd = prev._jrdd - self.is_cached = False - self.ctx = prev.ctx - self.prev = prev - self._jrdd_val = None - self._bypass_serializer = False - - @property - def _jrdd(self): - if self._jrdd_val: - return self._jrdd_val - func = self.func - if not self._bypass_serializer and self.ctx.batchSize != 1: - oldfunc = self.func - batchSize = self.ctx.batchSize - def batched_func(iterator): - return batched(oldfunc(iterator), batchSize) - func = batched_func - cmds = [func, self._bypass_serializer] - pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], - self.ctx.gateway._gateway_client) - self.ctx._pickled_broadcast_vars.clear() - class_manifest = self._prev_jrdd.classManifest() - env = copy.copy(self.ctx.environment) - env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "") - env = MapConverter().convert(env, self.ctx.gateway._gateway_client) - python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), - pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, class_manifest) - self._jrdd_val = python_rdd.asJavaRDD() - return self._jrdd_val - - -def _test(): - import doctest - from pyspark.context import SparkContext - globs = globals().copy() - # The small batch size here ensures that we see multiple batches, - # even in these small test examples: - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - doctest.testmod(globs=globs) - globs['sc'].stop() - - -if __name__ == "__main__": - _test() diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py deleted file mode 100644 index 9a5151ea00..0000000000 --- a/pyspark/pyspark/serializers.py +++ /dev/null @@ -1,78 +0,0 @@ -import struct -import cPickle - - -class Batch(object): - """ - Used to store multiple RDD entries as a single Java object. - - This relieves us from having to explicitly track whether an RDD - is stored as batches of objects and avoids problems when processing - the union() of batched and unbatched RDDs (e.g. the union() of textFile() - with another RDD). - """ - def __init__(self, items): - self.items = items - - -def batched(iterator, batchSize): - if batchSize == -1: # unlimited batch size - yield Batch(list(iterator)) - else: - items = [] - count = 0 - for item in iterator: - items.append(item) - count += 1 - if count == batchSize: - yield Batch(items) - items = [] - count = 0 - if items: - yield Batch(items) - - -def dump_pickle(obj): - return cPickle.dumps(obj, 2) - - -load_pickle = cPickle.loads - - -def read_long(stream): - length = stream.read(8) - if length == "": - raise EOFError - return struct.unpack("!q", length)[0] - - -def read_int(stream): - length = stream.read(4) - if length == "": - raise EOFError - return struct.unpack("!i", length)[0] - -def write_with_length(obj, stream): - stream.write(struct.pack("!i", len(obj))) - stream.write(obj) - - -def read_with_length(stream): - length = read_int(stream) - obj = stream.read(length) - if obj == "": - raise EOFError - return obj - - -def read_from_pickle_file(stream): - try: - while True: - obj = load_pickle(read_with_length(stream)) - if type(obj) == Batch: # We don't care about inheritance - for item in obj.items: - yield item - else: - yield obj - except EOFError: - return diff --git a/pyspark/pyspark/shell.py b/pyspark/pyspark/shell.py deleted file mode 100644 index bd39b0283f..0000000000 --- a/pyspark/pyspark/shell.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -An interactive shell. -""" -import optparse # I prefer argparse, but it's not included with Python < 2.7 -import code -import sys - -from pyspark.context import SparkContext - - -def main(master='local', ipython=False): - sc = SparkContext(master, 'PySparkShell') - user_ns = {'sc' : sc} - banner = "Spark context avaiable as sc." - if ipython: - import IPython - IPython.embed(user_ns=user_ns, banner2=banner) - else: - print banner - code.interact(local=user_ns) - - -if __name__ == '__main__': - usage = "usage: %prog [options] master" - parser = optparse.OptionParser(usage=usage) - parser.add_option("-i", "--ipython", help="Run IPython shell", - action="store_true") - (options, args) = parser.parse_args() - if len(sys.argv) > 1: - master = args[0] - else: - master = 'local' - main(master, options.ipython) diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py deleted file mode 100644 index 9f6b507dbd..0000000000 --- a/pyspark/pyspark/worker.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -Worker that receives input from Piped RDD. -""" -import sys -from base64 import standard_b64decode -# CloudPickler needs to be imported so that depicklers are registered using the -# copy_reg module. -from pyspark.broadcast import Broadcast, _broadcastRegistry -from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import write_with_length, read_with_length, \ - read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file - - -# Redirect stdout to stderr so that users must return values from functions. -old_stdout = sys.stdout -sys.stdout = sys.stderr - - -def load_obj(): - return load_pickle(standard_b64decode(sys.stdin.readline().strip())) - - -def main(): - num_broadcast_variables = read_int(sys.stdin) - for _ in range(num_broadcast_variables): - bid = read_long(sys.stdin) - value = read_with_length(sys.stdin) - _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) - func = load_obj() - bypassSerializer = load_obj() - if bypassSerializer: - dumps = lambda x: x - else: - dumps = dump_pickle - for obj in func(read_from_pickle_file(sys.stdin)): - write_with_length(dumps(obj), old_stdout) - - -if __name__ == '__main__': - main() diff --git a/pyspark/run-pyspark b/pyspark/run-pyspark deleted file mode 100755 index 4d10fbea8b..0000000000 --- a/pyspark/run-pyspark +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env bash - -# Figure out where the Scala framework is installed -FWDIR="$(cd `dirname $0`; cd ../; pwd)" - -# Export this as SPARK_HOME -export SPARK_HOME="$FWDIR" - -# Load environment variables from conf/spark-env.sh, if it exists -if [ -e $FWDIR/conf/spark-env.sh ] ; then - . $FWDIR/conf/spark-env.sh -fi - -# Figure out which Python executable to use -if [ -z "$PYSPARK_PYTHON" ] ; then - PYSPARK_PYTHON="python" -fi -export PYSPARK_PYTHON - -# Add the PySpark classes to the Python path: -export PYTHONPATH=$SPARK_HOME/pyspark/:$PYTHONPATH - -# Launch with `scala` by default: -if [[ "$SPARK_LAUNCH_WITH_SCALA" != "0" ]] ; then - export SPARK_LAUNCH_WITH_SCALA=1 -fi - -exec "$PYSPARK_PYTHON" "$@" diff --git a/python/.gitignore b/python/.gitignore new file mode 100644 index 0000000000..5c56e638f9 --- /dev/null +++ b/python/.gitignore @@ -0,0 +1,2 @@ +*.pyc +docs/ diff --git a/python/epydoc.conf b/python/epydoc.conf new file mode 100644 index 0000000000..91ac984ba2 --- /dev/null +++ b/python/epydoc.conf @@ -0,0 +1,19 @@ +[epydoc] # Epydoc section marker (required by ConfigParser) + +# Information about the project. +name: PySpark +url: http://spark-project.org + +# The list of modules to document. Modules can be named using +# dotted names, module filenames, or package directory names. +# This option may be repeated. +modules: pyspark + +# Write html output to the directory "apidocs" +output: html +target: docs/ + +private: no + +exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers + pyspark.java_gateway pyspark.examples pyspark.shell diff --git a/python/examples/kmeans.py b/python/examples/kmeans.py new file mode 100644 index 0000000000..ad2be21178 --- /dev/null +++ b/python/examples/kmeans.py @@ -0,0 +1,52 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" +import sys + +import numpy as np +from pyspark import SparkContext + + +def parseVector(line): + return np.array([float(x) for x in line.split(' ')]) + + +def closestPoint(p, centers): + bestIndex = 0 + closest = float("+inf") + for i in range(len(centers)): + tempDist = np.sum((p - centers[i]) ** 2) + if tempDist < closest: + closest = tempDist + bestIndex = i + return bestIndex + + +if __name__ == "__main__": + if len(sys.argv) < 5: + print >> sys.stderr, \ + "Usage: PythonKMeans " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + lines = sc.textFile(sys.argv[2]) + data = lines.map(parseVector).cache() + K = int(sys.argv[3]) + convergeDist = float(sys.argv[4]) + + kPoints = data.takeSample(False, K, 34) + tempDist = 1.0 + + while tempDist > convergeDist: + closest = data.map( + lambda p : (closestPoint(p, kPoints), (p, 1))) + pointStats = closest.reduceByKey( + lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2)) + newPoints = pointStats.map( + lambda (x, (y, z)): (x, y / z)).collect() + + tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) + + for (x, y) in newPoints: + kPoints[x] = y + + print "Final centers: " + str(kPoints) diff --git a/python/examples/logistic_regression.py b/python/examples/logistic_regression.py new file mode 100755 index 0000000000..f13698a86f --- /dev/null +++ b/python/examples/logistic_regression.py @@ -0,0 +1,57 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" +from collections import namedtuple +from math import exp +from os.path import realpath +import sys + +import numpy as np +from pyspark import SparkContext + + +N = 100000 # Number of data points +D = 10 # Number of dimensions +R = 0.7 # Scaling factor +ITERATIONS = 5 +np.random.seed(42) + + +DataPoint = namedtuple("DataPoint", ['x', 'y']) +from lr import DataPoint # So that DataPoint is properly serialized + + +def generateData(): + def generatePoint(i): + y = -1 if i % 2 == 0 else 1 + x = np.random.normal(size=D) + (y * R) + return DataPoint(x, y) + return [generatePoint(i) for i in range(N)] + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonLR []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)]) + slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 + points = sc.parallelize(generateData(), slices).cache() + + # Initialize w to a random value + w = 2 * np.random.ranf(size=D) - 1 + print "Initial w: " + str(w) + + def add(x, y): + x += y + return x + + for i in range(1, ITERATIONS + 1): + print "On iteration %i" % i + + gradient = points.map(lambda p: + (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x + ).reduce(add) + w -= gradient + + print "Final w: " + str(w) diff --git a/python/examples/pi.py b/python/examples/pi.py new file mode 100644 index 0000000000..127cba029b --- /dev/null +++ b/python/examples/pi.py @@ -0,0 +1,21 @@ +import sys +from random import random +from operator import add + +from pyspark import SparkContext + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonPi []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonPi") + slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 + n = 100000 * slices + def f(_): + x = random() * 2 - 1 + y = random() * 2 - 1 + return 1 if x ** 2 + y ** 2 < 1 else 0 + count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) + print "Pi is roughly %f" % (4.0 * count / n) diff --git a/python/examples/transitive_closure.py b/python/examples/transitive_closure.py new file mode 100644 index 0000000000..73f7f8fbaf --- /dev/null +++ b/python/examples/transitive_closure.py @@ -0,0 +1,50 @@ +import sys +from random import Random + +from pyspark import SparkContext + +numEdges = 200 +numVertices = 100 +rand = Random(42) + + +def generateGraph(): + edges = set() + while len(edges) < numEdges: + src = rand.randrange(0, numEdges) + dst = rand.randrange(0, numEdges) + if src != dst: + edges.add((src, dst)) + return edges + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonTC []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonTC") + slices = sys.argv[2] if len(sys.argv) > 2 else 2 + tc = sc.parallelize(generateGraph(), slices).cache() + + # Linear transitive closure: each round grows paths by one edge, + # by joining the graph's edges with the already-discovered paths. + # e.g. join the path (y, z) from the TC with the edge (x, y) from + # the graph to obtain the path (x, z). + + # Because join() joins on keys, the edges are stored in reversed order. + edges = tc.map(lambda (x, y): (y, x)) + + oldCount = 0L + nextCount = tc.count() + while True: + oldCount = nextCount + # Perform the join, obtaining an RDD of (y, (z, x)) pairs, + # then project the result to obtain the new (x, z) paths. + new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) + tc = tc.union(new_edges).distinct().cache() + nextCount = tc.count() + if nextCount == oldCount: + break + + print "TC has %i edges" % tc.count() diff --git a/python/examples/wordcount.py b/python/examples/wordcount.py new file mode 100644 index 0000000000..857160624b --- /dev/null +++ b/python/examples/wordcount.py @@ -0,0 +1,19 @@ +import sys +from operator import add + +from pyspark import SparkContext + + +if __name__ == "__main__": + if len(sys.argv) < 3: + print >> sys.stderr, \ + "Usage: PythonWordCount " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonWordCount") + lines = sc.textFile(sys.argv[2], 1) + counts = lines.flatMap(lambda x: x.split(' ')) \ + .map(lambda x: (x, 1)) \ + .reduceByKey(add) + output = counts.collect() + for (word, count) in output: + print "%s : %i" % (word, count) diff --git a/python/lib/PY4J_LICENSE.txt b/python/lib/PY4J_LICENSE.txt new file mode 100644 index 0000000000..a70279ca14 --- /dev/null +++ b/python/lib/PY4J_LICENSE.txt @@ -0,0 +1,27 @@ + +Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +- The name of the author may not be used to endorse or promote products +derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/python/lib/PY4J_VERSION.txt b/python/lib/PY4J_VERSION.txt new file mode 100644 index 0000000000..04a0cd52a8 --- /dev/null +++ b/python/lib/PY4J_VERSION.txt @@ -0,0 +1 @@ +b7924aabe9c5e63f0a4d8bbd17019534c7ec014e diff --git a/python/lib/py4j0.7.egg b/python/lib/py4j0.7.egg new file mode 100644 index 0000000000..f8a339d8ee Binary files /dev/null and b/python/lib/py4j0.7.egg differ diff --git a/python/lib/py4j0.7.jar b/python/lib/py4j0.7.jar new file mode 100644 index 0000000000..73b7ddb7d1 Binary files /dev/null and b/python/lib/py4j0.7.jar differ diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py new file mode 100644 index 0000000000..c595ae0842 --- /dev/null +++ b/python/pyspark/__init__.py @@ -0,0 +1,20 @@ +""" +PySpark is a Python API for Spark. + +Public classes: + + - L{SparkContext} + Main entry point for Spark functionality. + - L{RDD} + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.egg")) + + +from pyspark.context import SparkContext +from pyspark.rdd import RDD + + +__all__ = ["SparkContext", "RDD"] diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py new file mode 100644 index 0000000000..93876fa738 --- /dev/null +++ b/python/pyspark/broadcast.py @@ -0,0 +1,48 @@ +""" +>>> from pyspark.context import SparkContext +>>> sc = SparkContext('local', 'test') +>>> b = sc.broadcast([1, 2, 3, 4, 5]) +>>> b.value +[1, 2, 3, 4, 5] + +>>> from pyspark.broadcast import _broadcastRegistry +>>> _broadcastRegistry[b.bid] = b +>>> from cPickle import dumps, loads +>>> loads(dumps(b)).value +[1, 2, 3, 4, 5] + +>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() +[1, 2, 3, 4, 5, 1, 2, 3, 4, 5] + +>>> large_broadcast = sc.broadcast(list(range(10000))) +""" +# Holds broadcasted data received from Java, keyed by its id. +_broadcastRegistry = {} + + +def _from_id(bid): + from pyspark.broadcast import _broadcastRegistry + if bid not in _broadcastRegistry: + raise Exception("Broadcast variable '%s' not loaded!" % bid) + return _broadcastRegistry[bid] + + +class Broadcast(object): + def __init__(self, bid, value, java_broadcast=None, pickle_registry=None): + self.value = value + self.bid = bid + self._jbroadcast = java_broadcast + self._pickle_registry = pickle_registry + + def __reduce__(self): + self._pickle_registry.add(self) + return (_from_id, (self.bid, )) + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py new file mode 100644 index 0000000000..6a7c23a069 --- /dev/null +++ b/python/pyspark/cloudpickle.py @@ -0,0 +1,974 @@ +""" +This class is defined to override standard pickle functionality + +The goals of it follow: +-Serialize lambdas and nested functions to compiled byte code +-Deal with main module correctly +-Deal with other non-serializable objects + +It does not include an unpickler, as standard python unpickling suffices. + +This module was extracted from the `cloud` package, developed by `PiCloud, Inc. +`_. + +Copyright (c) 2012, Regents of the University of California. +Copyright (c) 2009 `PiCloud, Inc. `_. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the University of California, Berkeley nor the + names of its contributors may be used to endorse or promote + products derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + +import operator +import os +import pickle +import struct +import sys +import types +from functools import partial +import itertools +from copy_reg import _extension_registry, _inverted_registry, _extension_cache +import new +import dis +import traceback + +#relevant opcodes +STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) +DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) +LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) +GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] + +HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) +EXTENDED_ARG = chr(dis.EXTENDED_ARG) + +import logging +cloudLog = logging.getLogger("Cloud.Transport") + +try: + import ctypes +except (MemoryError, ImportError): + logging.warning('Exception raised on importing ctypes. Likely python bug.. some functionality will be disabled', exc_info = True) + ctypes = None + PyObject_HEAD = None +else: + + # for reading internal structures + PyObject_HEAD = [ + ('ob_refcnt', ctypes.c_size_t), + ('ob_type', ctypes.c_void_p), + ] + + +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO + +# These helper functions were copied from PiCloud's util module. +def islambda(func): + return getattr(func,'func_name') == '' + +def xrange_params(xrangeobj): + """Returns a 3 element tuple describing the xrange start, step, and len + respectively + + Note: Only guarentees that elements of xrange are the same. parameters may + be different. + e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same + though w/ iteration + """ + + xrange_len = len(xrangeobj) + if not xrange_len: #empty + return (0,1,0) + start = xrangeobj[0] + if xrange_len == 1: #one element + return start, 1, 1 + return (start, xrangeobj[1] - xrangeobj[0], xrange_len) + +#debug variables intended for developer use: +printSerialization = False +printMemoization = False + +useForcedImports = True #Should I use forced imports for tracking? + + + +class CloudPickler(pickle.Pickler): + + dispatch = pickle.Pickler.dispatch.copy() + savedForceImports = False + savedDjangoEnv = False #hack tro transport django environment + + def __init__(self, file, protocol=None, min_size_to_save= 0): + pickle.Pickler.__init__(self,file,protocol) + self.modules = set() #set of modules needed to depickle + self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env + + def dump(self, obj): + # note: not thread safe + # minimal side-effects, so not fixing + recurse_limit = 3000 + base_recurse = sys.getrecursionlimit() + if base_recurse < recurse_limit: + sys.setrecursionlimit(recurse_limit) + self.inject_addons() + try: + return pickle.Pickler.dump(self, obj) + except RuntimeError, e: + if 'recursion' in e.args[0]: + msg = """Could not pickle object as excessively deep recursion required. + Try _fast_serialization=2 or contact PiCloud support""" + raise pickle.PicklingError(msg) + finally: + new_recurse = sys.getrecursionlimit() + if new_recurse == recurse_limit: + sys.setrecursionlimit(base_recurse) + + def save_buffer(self, obj): + """Fallback to save_string""" + pickle.Pickler.save_string(self,str(obj)) + dispatch[buffer] = save_buffer + + #block broken objects + def save_unsupported(self, obj, pack=None): + raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) + dispatch[types.GeneratorType] = save_unsupported + + #python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it + try: + slice(0,1).__reduce__() + except TypeError: #can't pickle - + dispatch[slice] = save_unsupported + + #itertools objects do not pickle! + for v in itertools.__dict__.values(): + if type(v) is type: + dispatch[v] = save_unsupported + + + def save_dict(self, obj): + """hack fix + If the dict is a global, deal with it in a special way + """ + #print 'saving', obj + if obj is __builtins__: + self.save_reduce(_get_module_builtins, (), obj=obj) + else: + pickle.Pickler.save_dict(self, obj) + dispatch[pickle.DictionaryType] = save_dict + + + def save_module(self, obj, pack=struct.pack): + """ + Save a module as an import + """ + #print 'try save import', obj.__name__ + self.modules.add(obj) + self.save_reduce(subimport,(obj.__name__,), obj=obj) + dispatch[types.ModuleType] = save_module #new type + + def save_codeobject(self, obj, pack=struct.pack): + """ + Save a code object + """ + #print 'try to save codeobj: ', obj + args = ( + obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, + obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, + obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars + ) + self.save_reduce(types.CodeType, args, obj=obj) + dispatch[types.CodeType] = save_codeobject #new type + + def save_function(self, obj, name=None, pack=struct.pack): + """ Registered with the dispatch to handle all function types. + + Determines what kind of function obj is (e.g. lambda, defined at + interactive prompt, etc) and handles the pickling appropriately. + """ + write = self.write + + name = obj.__name__ + modname = pickle.whichmodule(obj, name) + #print 'which gives %s %s %s' % (modname, obj, name) + try: + themodule = sys.modules[modname] + except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__ + modname = '__main__' + + if modname == '__main__': + themodule = None + + if themodule: + self.modules.add(themodule) + + if not self.savedDjangoEnv: + #hack for django - if we detect the settings module, we transport it + django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '') + if django_settings: + django_mod = sys.modules.get(django_settings) + if django_mod: + cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name) + self.savedDjangoEnv = True + self.modules.add(django_mod) + write(pickle.MARK) + self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod) + write(pickle.POP_MARK) + + + # if func is lambda, def'ed at prompt, is in main, or is nested, then + # we'll pickle the actual function object rather than simply saving a + # reference (as is done in default pickler), via save_function_tuple. + if islambda(obj) or obj.func_code.co_filename == '' or themodule == None: + #Force server to import modules that have been imported in main + modList = None + if themodule == None and not self.savedForceImports: + mainmod = sys.modules['__main__'] + if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'): + modList = list(mainmod.___pyc_forcedImports__) + self.savedForceImports = True + self.save_function_tuple(obj, modList) + return + else: # func is nested + klass = getattr(themodule, name, None) + if klass is None or klass is not obj: + self.save_function_tuple(obj, [themodule]) + return + + if obj.__dict__: + # essentially save_reduce, but workaround needed to avoid recursion + self.save(_restore_attr) + write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n') + self.memoize(obj) + self.save(obj.__dict__) + write(pickle.TUPLE + pickle.REDUCE) + else: + write(pickle.GLOBAL + modname + '\n' + name + '\n') + self.memoize(obj) + dispatch[types.FunctionType] = save_function + + def save_function_tuple(self, func, forced_imports): + """ Pickles an actual func object. + + A func comprises: code, globals, defaults, closure, and dict. We + extract and save these, injecting reducing functions at certain points + to recreate the func object. Keep in mind that some of these pieces + can contain a ref to the func itself. Thus, a naive save on these + pieces could trigger an infinite loop of save's. To get around that, + we first create a skeleton func object using just the code (this is + safe, since this won't contain a ref to the func), and memoize it as + soon as it's created. The other stuff can then be filled in later. + """ + save = self.save + write = self.write + + # save the modules (if any) + if forced_imports: + write(pickle.MARK) + save(_modules_to_main) + #print 'forced imports are', forced_imports + + forced_names = map(lambda m: m.__name__, forced_imports) + save((forced_names,)) + + #save((forced_imports,)) + write(pickle.REDUCE) + write(pickle.POP_MARK) + + code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) + + save(_fill_function) # skeleton function updater + write(pickle.MARK) # beginning of tuple that _fill_function expects + + # create a skeleton function object and memoize it + save(_make_skel_func) + save((code, len(closure), base_globals)) + write(pickle.REDUCE) + self.memoize(func) + + # save the rest of the func data needed by _fill_function + save(f_globals) + save(defaults) + save(closure) + save(dct) + write(pickle.TUPLE) + write(pickle.REDUCE) # applies _fill_function on the tuple + + @staticmethod + def extract_code_globals(co): + """ + Find all globals names read or written to by codeblock co + """ + code = co.co_code + names = co.co_names + out_names = set() + + n = len(code) + i = 0 + extended_arg = 0 + while i < n: + op = code[i] + + i = i+1 + if op >= HAVE_ARGUMENT: + oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg + extended_arg = 0 + i = i+2 + if op == EXTENDED_ARG: + extended_arg = oparg*65536L + if op in GLOBAL_OPS: + out_names.add(names[oparg]) + #print 'extracted', out_names, ' from ', names + return out_names + + def extract_func_data(self, func): + """ + Turn the function into a tuple of data necessary to recreate it: + code, globals, defaults, closure, dict + """ + code = func.func_code + + # extract all global ref's + func_global_refs = CloudPickler.extract_code_globals(code) + if code.co_consts: # see if nested function have any global refs + for const in code.co_consts: + if type(const) is types.CodeType and const.co_names: + func_global_refs = func_global_refs.union( CloudPickler.extract_code_globals(const)) + # process all variables referenced by global environment + f_globals = {} + for var in func_global_refs: + #Some names, such as class functions are not global - we don't need them + if func.func_globals.has_key(var): + f_globals[var] = func.func_globals[var] + + # defaults requires no processing + defaults = func.func_defaults + + def get_contents(cell): + try: + return cell.cell_contents + except ValueError, e: #cell is empty error on not yet assigned + raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope') + + + # process closure + if func.func_closure: + closure = map(get_contents, func.func_closure) + else: + closure = [] + + # save the dict + dct = func.func_dict + + if printSerialization: + outvars = ['code: ' + str(code) ] + outvars.append('globals: ' + str(f_globals)) + outvars.append('defaults: ' + str(defaults)) + outvars.append('closure: ' + str(closure)) + print 'function ', func, 'is extracted to: ', ', '.join(outvars) + + base_globals = self.globals_ref.get(id(func.func_globals), {}) + self.globals_ref[id(func.func_globals)] = base_globals + + return (code, f_globals, defaults, closure, dct, base_globals) + + def save_global(self, obj, name=None, pack=struct.pack): + write = self.write + memo = self.memo + + if name is None: + name = obj.__name__ + + modname = getattr(obj, "__module__", None) + if modname is None: + modname = pickle.whichmodule(obj, name) + + try: + __import__(modname) + themodule = sys.modules[modname] + except (ImportError, KeyError, AttributeError): #should never occur + raise pickle.PicklingError( + "Can't pickle %r: Module %s cannot be found" % + (obj, modname)) + + if modname == '__main__': + themodule = None + + if themodule: + self.modules.add(themodule) + + sendRef = True + typ = type(obj) + #print 'saving', obj, typ + try: + try: #Deal with case when getattribute fails with exceptions + klass = getattr(themodule, name) + except (AttributeError): + if modname == '__builtin__': #new.* are misrepeported + modname = 'new' + __import__(modname) + themodule = sys.modules[modname] + try: + klass = getattr(themodule, name) + except AttributeError, a: + #print themodule, name, obj, type(obj) + raise pickle.PicklingError("Can't pickle builtin %s" % obj) + else: + raise + + except (ImportError, KeyError, AttributeError): + if typ == types.TypeType or typ == types.ClassType: + sendRef = False + else: #we can't deal with this + raise + else: + if klass is not obj and (typ == types.TypeType or typ == types.ClassType): + sendRef = False + if not sendRef: + #note: Third party types might crash this - add better checks! + d = dict(obj.__dict__) #copy dict proxy to a dict + if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties + d.pop('__dict__',None) + d.pop('__weakref__',None) + + # hack as __new__ is stored differently in the __dict__ + new_override = d.get('__new__', None) + if new_override: + d['__new__'] = obj.__new__ + + self.save_reduce(type(obj),(obj.__name__,obj.__bases__, + d),obj=obj) + #print 'internal reduce dask %s %s' % (obj, d) + return + + if self.proto >= 2: + code = _extension_registry.get((modname, name)) + if code: + assert code > 0 + if code <= 0xff: + write(pickle.EXT1 + chr(code)) + elif code <= 0xffff: + write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8)) + else: + write(pickle.EXT4 + pack("= 2 and getattr(func, "__name__", "") == "__newobj__": + #Added fix to allow transient + cls = args[0] + if not hasattr(cls, "__new__"): + raise pickle.PicklingError( + "args[0] from __newobj__ args has no __new__") + if obj is not None and cls is not obj.__class__: + raise pickle.PicklingError( + "args[0] from __newobj__ args has the wrong class") + args = args[1:] + save(cls) + + #Don't pickle transient entries + if hasattr(obj, '__transient__'): + transient = obj.__transient__ + state = state.copy() + + for k in list(state.keys()): + if k in transient: + del state[k] + + save(args) + write(pickle.NEWOBJ) + else: + save(func) + save(args) + write(pickle.REDUCE) + + if obj is not None: + self.memoize(obj) + + # More new special cases (that work with older protocols as + # well): when __reduce__ returns a tuple with 4 or 5 items, + # the 4th and 5th item should be iterators that provide list + # items and dict items (as (key, value) tuples), or None. + + if listitems is not None: + self._batch_appends(listitems) + + if dictitems is not None: + self._batch_setitems(dictitems) + + if state is not None: + #print 'obj %s has state %s' % (obj, state) + save(state) + write(pickle.BUILD) + + + def save_xrange(self, obj): + """Save an xrange object in python 2.5 + Python 2.6 supports this natively + """ + range_params = xrange_params(obj) + self.save_reduce(_build_xrange,range_params) + + #python2.6+ supports xrange pickling. some py2.5 extensions might as well. We just test it + try: + xrange(0).__reduce__() + except TypeError: #can't pickle -- use PiCloud pickler + dispatch[xrange] = save_xrange + + def save_partial(self, obj): + """Partial objects do not serialize correctly in python2.x -- this fixes the bugs""" + self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords)) + + if sys.version_info < (2,7): #2.7 supports partial pickling + dispatch[partial] = save_partial + + + def save_file(self, obj): + """Save a file""" + import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute + from ..transport.adapter import SerializingAdapter + + if not hasattr(obj, 'name') or not hasattr(obj, 'mode'): + raise pickle.PicklingError("Cannot pickle files that do not map to an actual file") + if obj.name == '': + return self.save_reduce(getattr, (sys,'stdout'), obj=obj) + if obj.name == '': + return self.save_reduce(getattr, (sys,'stderr'), obj=obj) + if obj.name == '': + raise pickle.PicklingError("Cannot pickle standard input") + if hasattr(obj, 'isatty') and obj.isatty(): + raise pickle.PicklingError("Cannot pickle files that map to tty objects") + if 'r' not in obj.mode: + raise pickle.PicklingError("Cannot pickle files that are not opened for reading") + name = obj.name + try: + fsize = os.stat(name).st_size + except OSError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be stat" % name) + + if obj.closed: + #create an empty closed string io + retval = pystringIO.StringIO("") + retval.close() + elif not fsize: #empty file + retval = pystringIO.StringIO("") + try: + tmpfile = file(name) + tst = tmpfile.read(1) + except IOError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) + tmpfile.close() + if tst != '': + raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name) + elif fsize > SerializingAdapter.max_transmit_data: + raise pickle.PicklingError("Cannot pickle file %s as it exceeds cloudconf.py's max_transmit_data of %d" % + (name,SerializingAdapter.max_transmit_data)) + else: + try: + tmpfile = file(name) + contents = tmpfile.read(SerializingAdapter.max_transmit_data) + tmpfile.close() + except IOError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) + retval = pystringIO.StringIO(contents) + curloc = obj.tell() + retval.seek(curloc) + + retval.name = name + self.save(retval) #save stringIO + self.memoize(obj) + + dispatch[file] = save_file + """Special functions for Add-on libraries""" + + def inject_numpy(self): + numpy = sys.modules.get('numpy') + if not numpy or not hasattr(numpy, 'ufunc'): + return + self.dispatch[numpy.ufunc] = self.__class__.save_ufunc + + numpy_tst_mods = ['numpy', 'scipy.special'] + def save_ufunc(self, obj): + """Hack function for saving numpy ufunc objects""" + name = obj.__name__ + for tst_mod_name in self.numpy_tst_mods: + tst_mod = sys.modules.get(tst_mod_name, None) + if tst_mod: + if name in tst_mod.__dict__: + self.save_reduce(_getobject, (tst_mod_name, name)) + return + raise pickle.PicklingError('cannot save %s. Cannot resolve what module it is defined in' % str(obj)) + + def inject_timeseries(self): + """Handle bugs with pickling scikits timeseries""" + tseries = sys.modules.get('scikits.timeseries.tseries') + if not tseries or not hasattr(tseries, 'Timeseries'): + return + self.dispatch[tseries.Timeseries] = self.__class__.save_timeseries + + def save_timeseries(self, obj): + import scikits.timeseries.tseries as ts + + func, reduce_args, state = obj.__reduce__() + if func != ts._tsreconstruct: + raise pickle.PicklingError('timeseries using unexpected reconstruction function %s' % str(func)) + state = (1, + obj.shape, + obj.dtype, + obj.flags.fnc, + obj._data.tostring(), + ts.getmaskarray(obj).tostring(), + obj._fill_value, + obj._dates.shape, + obj._dates.__array__().tostring(), + obj._dates.dtype, #added -- preserve type + obj.freq, + obj._optinfo, + ) + return self.save_reduce(_genTimeSeries, (reduce_args, state)) + + def inject_email(self): + """Block email LazyImporters from being saved""" + email = sys.modules.get('email') + if not email: + return + self.dispatch[email.LazyImporter] = self.__class__.save_unsupported + + def inject_addons(self): + """Plug in system. Register additional pickling functions if modules already loaded""" + self.inject_numpy() + self.inject_timeseries() + self.inject_email() + + """Python Imaging Library""" + def save_image(self, obj): + if not obj.im and obj.fp and 'r' in obj.fp.mode and obj.fp.name \ + and not obj.fp.closed and (not hasattr(obj, 'isatty') or not obj.isatty()): + #if image not loaded yet -- lazy load + self.save_reduce(_lazyloadImage,(obj.fp,), obj=obj) + else: + #image is loaded - just transmit it over + self.save_reduce(_generateImage, (obj.size, obj.mode, obj.tostring()), obj=obj) + + """ + def memoize(self, obj): + pickle.Pickler.memoize(self, obj) + if printMemoization: + print 'memoizing ' + str(obj) + """ + + + +# Shorthands for legacy support + +def dump(obj, file, protocol=2): + CloudPickler(file, protocol).dump(obj) + +def dumps(obj, protocol=2): + file = StringIO() + + cp = CloudPickler(file,protocol) + cp.dump(obj) + + #print 'cloud dumped', str(obj), str(cp.modules) + + return file.getvalue() + + +#hack for __import__ not working as desired +def subimport(name): + __import__(name) + return sys.modules[name] + +#hack to load django settings: +def django_settings_load(name): + modified_env = False + + if 'DJANGO_SETTINGS_MODULE' not in os.environ: + os.environ['DJANGO_SETTINGS_MODULE'] = name # must set name first due to circular deps + modified_env = True + try: + module = subimport(name) + except Exception, i: + print >> sys.stderr, 'Cloud not import django settings %s:' % (name) + print_exec(sys.stderr) + if modified_env: + del os.environ['DJANGO_SETTINGS_MODULE'] + else: + #add project directory to sys,path: + if hasattr(module,'__file__'): + dirname = os.path.split(module.__file__)[0] + '/' + sys.path.append(dirname) + +# restores function attributes +def _restore_attr(obj, attr): + for key, val in attr.items(): + setattr(obj, key, val) + return obj + +def _get_module_builtins(): + return pickle.__builtins__ + +def print_exec(stream): + ei = sys.exc_info() + traceback.print_exception(ei[0], ei[1], ei[2], None, stream) + +def _modules_to_main(modList): + """Force every module in modList to be placed into main""" + if not modList: + return + + main = sys.modules['__main__'] + for modname in modList: + if type(modname) is str: + try: + mod = __import__(modname) + except Exception, i: #catch all... + sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \ +A version mismatch is likely. Specific error was:\n' % modname) + print_exec(sys.stderr) + else: + setattr(main,mod.__name__, mod) + else: + #REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD) + #In old version actual module was sent + setattr(main,modname.__name__, modname) + +#object generators: +def _build_xrange(start, step, len): + """Built xrange explicitly""" + return xrange(start, start + step*len, step) + +def _genpartial(func, args, kwds): + if not args: + args = () + if not kwds: + kwds = {} + return partial(func, *args, **kwds) + + +def _fill_function(func, globals, defaults, closure, dict): + """ Fills in the rest of function data into the skeleton function object + that were created via _make_skel_func(). + """ + func.func_globals.update(globals) + func.func_defaults = defaults + func.func_dict = dict + + if len(closure) != len(func.func_closure): + raise pickle.UnpicklingError("closure lengths don't match up") + for i in range(len(closure)): + _change_cell_value(func.func_closure[i], closure[i]) + + return func + +def _make_skel_func(code, num_closures, base_globals = None): + """ Creates a skeleton function object that contains just the provided + code and the correct number of cells in func_closure. All other + func attributes (e.g. func_globals) are empty. + """ + #build closure (cells): + if not ctypes: + raise Exception('ctypes failed to import; cannot build function') + + cellnew = ctypes.pythonapi.PyCell_New + cellnew.restype = ctypes.py_object + cellnew.argtypes = (ctypes.py_object,) + dummy_closure = tuple(map(lambda i: cellnew(None), range(num_closures))) + + if base_globals is None: + base_globals = {} + base_globals['__builtins__'] = __builtins__ + + return types.FunctionType(code, base_globals, + None, None, dummy_closure) + +# this piece of opaque code is needed below to modify 'cell' contents +cell_changer_code = new.code( + 1, 1, 2, 0, + ''.join([ + chr(dis.opmap['LOAD_FAST']), '\x00\x00', + chr(dis.opmap['DUP_TOP']), + chr(dis.opmap['STORE_DEREF']), '\x00\x00', + chr(dis.opmap['RETURN_VALUE']) + ]), + (), (), ('newval',), '', 'cell_changer', 1, '', ('c',), () +) + +def _change_cell_value(cell, newval): + """ Changes the contents of 'cell' object to newval """ + return new.function(cell_changer_code, {}, None, (), (cell,))(newval) + +"""Constructors for 3rd party libraries +Note: These can never be renamed due to client compatibility issues""" + +def _getobject(modname, attribute): + mod = __import__(modname) + return mod.__dict__[attribute] + +def _generateImage(size, mode, str_rep): + """Generate image from string representation""" + import Image + i = Image.new(mode, size) + i.fromstring(str_rep) + return i + +def _lazyloadImage(fp): + import Image + fp.seek(0) #works in almost any case + return Image.open(fp) + +"""Timeseries""" +def _genTimeSeries(reduce_args, state): + import scikits.timeseries.tseries as ts + from numpy import ndarray + from numpy.ma import MaskedArray + + + time_series = ts._tsreconstruct(*reduce_args) + + #from setstate modified + (ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state + #print 'regenerating %s' % dtyp + + MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv)) + _dates = time_series._dates + #_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ + ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm)) + _dates.freq = frq + _dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None, + toobj=None, toord=None, tostr=None)) + # Update the _optinfo dictionary + time_series._optinfo.update(infodict) + return time_series + diff --git a/python/pyspark/context.py b/python/pyspark/context.py new file mode 100644 index 0000000000..6172d69dcf --- /dev/null +++ b/python/pyspark/context.py @@ -0,0 +1,158 @@ +import os +import atexit +from tempfile import NamedTemporaryFile + +from pyspark.broadcast import Broadcast +from pyspark.java_gateway import launch_gateway +from pyspark.serializers import dump_pickle, write_with_length, batched +from pyspark.rdd import RDD + +from py4j.java_collections import ListConverter + + +class SparkContext(object): + """ + Main entry point for Spark functionality. A SparkContext represents the + connection to a Spark cluster, and can be used to create L{RDD}s and + broadcast variables on that cluster. + """ + + gateway = launch_gateway() + jvm = gateway.jvm + _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile + _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile + + def __init__(self, master, jobName, sparkHome=None, pyFiles=None, + environment=None, batchSize=1024): + """ + Create a new SparkContext. + + @param master: Cluster URL to connect to + (e.g. mesos://host:port, spark://host:port, local[4]). + @param jobName: A name for your job, to display on the cluster web UI + @param sparkHome: Location where Spark is installed on cluster nodes. + @param pyFiles: Collection of .zip or .py files to send to the cluster + and add to PYTHONPATH. These can be paths on the local file + system or HDFS, HTTP, HTTPS, or FTP URLs. + @param environment: A dictionary of environment variables to set on + worker nodes. + @param batchSize: The number of Python objects represented as a single + Java object. Set 1 to disable batching or -1 to use an + unlimited batch size. + """ + self.master = master + self.jobName = jobName + self.sparkHome = sparkHome or None # None becomes null in Py4J + self.environment = environment or {} + self.batchSize = batchSize # -1 represents a unlimited batch size + + # Create the Java SparkContext through Py4J + empty_string_array = self.gateway.new_array(self.jvm.String, 0) + self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, + empty_string_array) + + self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python') + # Broadcast's __reduce__ method stores Broadcast instances here. + # This allows other code to determine which Broadcast instances have + # been pickled, so it can determine which Java broadcast objects to + # send. + self._pickled_broadcast_vars = set() + + # Deploy any code dependencies specified in the constructor + for path in (pyFiles or []): + self.addPyFile(path) + + @property + def defaultParallelism(self): + """ + Default level of parallelism to use when not given by user (e.g. for + reduce tasks) + """ + return self._jsc.sc().defaultParallelism() + + def __del__(self): + if self._jsc: + self._jsc.stop() + + def stop(self): + """ + Shut down the SparkContext. + """ + self._jsc.stop() + self._jsc = None + + def parallelize(self, c, numSlices=None): + """ + Distribute a local Python collection to form an RDD. + """ + numSlices = numSlices or self.defaultParallelism + # Calling the Java parallelize() method with an ArrayList is too slow, + # because it sends O(n) Py4J commands. As an alternative, serialized + # objects are written to a file and loaded through textFile(). + tempFile = NamedTemporaryFile(delete=False) + atexit.register(lambda: os.unlink(tempFile.name)) + if self.batchSize != 1: + c = batched(c, self.batchSize) + for x in c: + write_with_length(dump_pickle(x), tempFile) + tempFile.close() + jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) + return RDD(jrdd, self) + + def textFile(self, name, minSplits=None): + """ + Read a text file from HDFS, a local file system (available on all + nodes), or any Hadoop-supported file system URI, and return it as an + RDD of Strings. + """ + minSplits = minSplits or min(self.defaultParallelism, 2) + jrdd = self._jsc.textFile(name, minSplits) + return RDD(jrdd, self) + + def union(self, rdds): + """ + Build the union of a list of RDDs. + """ + first = rdds[0]._jrdd + rest = [x._jrdd for x in rdds[1:]] + rest = ListConverter().convert(rest, self.gateway._gateway_client) + return RDD(self._jsc.union(first, rest), self) + + def broadcast(self, value): + """ + Broadcast a read-only variable to the cluster, returning a C{Broadcast} + object for reading it in distributed functions. The variable will be + sent to each cluster only once. + """ + jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) + return Broadcast(jbroadcast.id(), value, jbroadcast, + self._pickled_broadcast_vars) + + def addFile(self, path): + """ + Add a file to be downloaded into the working directory of this Spark + job on every node. The C{path} passed can be either a local file, + a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, + HTTPS or FTP URI. + """ + self._jsc.sc().addFile(path) + + def clearFiles(self): + """ + Clear the job's list of files added by L{addFile} or L{addPyFile} so + that they do not get downloaded to any new nodes. + """ + # TODO: remove added .py or .zip files from the PYTHONPATH? + self._jsc.sc().clearFiles() + + def addPyFile(self, path): + """ + Add a .py or .zip dependency for all tasks to be executed on this + SparkContext in the future. The C{path} passed can be either a local + file, a file in HDFS (or other Hadoop-supported filesystems), or an + HTTP, HTTPS or FTP URI. + """ + self.addFile(path) + filename = path.split("/")[-1] + os.environ["PYTHONPATH"] = \ + "%s:%s" % (filename, os.environ["PYTHONPATH"]) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py new file mode 100644 index 0000000000..2329e536cc --- /dev/null +++ b/python/pyspark/java_gateway.py @@ -0,0 +1,38 @@ +import os +import sys +from subprocess import Popen, PIPE +from threading import Thread +from py4j.java_gateway import java_import, JavaGateway, GatewayClient + + +SPARK_HOME = os.environ["SPARK_HOME"] + + +def launch_gateway(): + # Launch the Py4j gateway using Spark's run command so that we pick up the + # proper classpath and SPARK_MEM settings from spark-env.sh + command = [os.path.join(SPARK_HOME, "run"), "py4j.GatewayServer", + "--die-on-broken-pipe", "0"] + proc = Popen(command, stdout=PIPE, stdin=PIPE) + # Determine which ephemeral port the server started on: + port = int(proc.stdout.readline()) + # Create a thread to echo output from the GatewayServer, which is required + # for Java log output to show up: + class EchoOutputThread(Thread): + def __init__(self, stream): + Thread.__init__(self) + self.daemon = True + self.stream = stream + + def run(self): + while True: + line = self.stream.readline() + sys.stderr.write(line) + EchoOutputThread(proc.stdout).start() + # Connect to the gateway + gateway = JavaGateway(GatewayClient(port=port), auto_convert=False) + # Import the classes used by PySpark + java_import(gateway.jvm, "spark.api.java.*") + java_import(gateway.jvm, "spark.api.python.*") + java_import(gateway.jvm, "scala.Tuple2") + return gateway diff --git a/python/pyspark/join.py b/python/pyspark/join.py new file mode 100644 index 0000000000..7036c47980 --- /dev/null +++ b/python/pyspark/join.py @@ -0,0 +1,92 @@ +""" +Copyright (c) 2011, Douban Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + + * Neither the name of the Douban Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + +def _do_python_join(rdd, other, numSplits, dispatch): + vs = rdd.map(lambda (k, v): (k, (1, v))) + ws = other.map(lambda (k, v): (k, (2, v))) + return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch) + + +def python_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + return [(v, w) for v in vbuf for w in wbuf] + return _do_python_join(rdd, other, numSplits, dispatch) + + +def python_right_outer_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not vbuf: + vbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + return _do_python_join(rdd, other, numSplits, dispatch) + + +def python_left_outer_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not wbuf: + wbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + return _do_python_join(rdd, other, numSplits, dispatch) + + +def python_cogroup(rdd, other, numSplits): + vs = rdd.map(lambda (k, v): (k, (1, v))) + ws = other.map(lambda (k, v): (k, (2, v))) + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + return (vbuf, wbuf) + return vs.union(ws).groupByKey(numSplits).mapValues(dispatch) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py new file mode 100644 index 0000000000..cbffb6cc1f --- /dev/null +++ b/python/pyspark/rdd.py @@ -0,0 +1,713 @@ +import atexit +from base64 import standard_b64encode as b64enc +import copy +from collections import defaultdict +from itertools import chain, ifilter, imap, product +import operator +import os +import shlex +from subprocess import Popen, PIPE +from tempfile import NamedTemporaryFile +from threading import Thread + +from pyspark import cloudpickle +from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ + read_from_pickle_file +from pyspark.join import python_join, python_left_outer_join, \ + python_right_outer_join, python_cogroup + +from py4j.java_collections import ListConverter, MapConverter + + +__all__ = ["RDD"] + + +class RDD(object): + """ + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + Represents an immutable, partitioned collection of elements that can be + operated on in parallel. + """ + + def __init__(self, jrdd, ctx): + self._jrdd = jrdd + self.is_cached = False + self.ctx = ctx + + @property + def context(self): + """ + The L{SparkContext} that this RDD was created on. + """ + return self.ctx + + def cache(self): + """ + Persist this RDD with the default storage level (C{MEMORY_ONLY}). + """ + self.is_cached = True + self._jrdd.cache() + return self + + # TODO persist(self, storageLevel) + + def map(self, f, preservesPartitioning=False): + """ + Return a new RDD containing the distinct elements in this RDD. + """ + def func(iterator): return imap(f, iterator) + return PipelinedRDD(self, func, preservesPartitioning) + + def flatMap(self, f, preservesPartitioning=False): + """ + Return a new RDD by first applying a function to all elements of this + RDD, and then flattening the results. + + >>> rdd = sc.parallelize([2, 3, 4]) + >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect()) + [1, 1, 1, 2, 2, 3] + >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) + [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] + """ + def func(iterator): return chain.from_iterable(imap(f, iterator)) + return self.mapPartitions(func, preservesPartitioning) + + def mapPartitions(self, f, preservesPartitioning=False): + """ + Return a new RDD by applying a function to each partition of this RDD. + + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) + >>> def f(iterator): yield sum(iterator) + >>> rdd.mapPartitions(f).collect() + [3, 7] + """ + return PipelinedRDD(self, f, preservesPartitioning) + + # TODO: mapPartitionsWithSplit + + def filter(self, f): + """ + Return a new RDD containing only the elements that satisfy a predicate. + + >>> rdd = sc.parallelize([1, 2, 3, 4, 5]) + >>> rdd.filter(lambda x: x % 2 == 0).collect() + [2, 4] + """ + def func(iterator): return ifilter(f, iterator) + return self.mapPartitions(func) + + def distinct(self): + """ + Return a new RDD containing the distinct elements in this RDD. + + >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) + [1, 2, 3] + """ + return self.map(lambda x: (x, "")) \ + .reduceByKey(lambda x, _: x) \ + .map(lambda (x, _): x) + + # TODO: sampling needs to be re-implemented due to Batch + #def sample(self, withReplacement, fraction, seed): + # jrdd = self._jrdd.sample(withReplacement, fraction, seed) + # return RDD(jrdd, self.ctx) + + #def takeSample(self, withReplacement, num, seed): + # vals = self._jrdd.takeSample(withReplacement, num, seed) + # return [load_pickle(bytes(x)) for x in vals] + + def union(self, other): + """ + Return the union of this RDD and another one. + + >>> rdd = sc.parallelize([1, 1, 2, 3]) + >>> rdd.union(rdd).collect() + [1, 1, 2, 3, 1, 1, 2, 3] + """ + return RDD(self._jrdd.union(other._jrdd), self.ctx) + + def __add__(self, other): + """ + Return the union of this RDD and another one. + + >>> rdd = sc.parallelize([1, 1, 2, 3]) + >>> (rdd + rdd).collect() + [1, 1, 2, 3, 1, 1, 2, 3] + """ + if not isinstance(other, RDD): + raise TypeError + return self.union(other) + + # TODO: sort + + def glom(self): + """ + Return an RDD created by coalescing all elements within each partition + into a list. + + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) + >>> sorted(rdd.glom().collect()) + [[1, 2], [3, 4]] + """ + def func(iterator): yield list(iterator) + return self.mapPartitions(func) + + def cartesian(self, other): + """ + Return the Cartesian product of this RDD and another one, that is, the + RDD of all pairs of elements C{(a, b)} where C{a} is in C{self} and + C{b} is in C{other}. + + >>> rdd = sc.parallelize([1, 2]) + >>> sorted(rdd.cartesian(rdd).collect()) + [(1, 1), (1, 2), (2, 1), (2, 2)] + """ + # Due to batching, we can't use the Java cartesian method. + java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx) + def unpack_batches(pair): + (x, y) = pair + if type(x) == Batch or type(y) == Batch: + xs = x.items if type(x) == Batch else [x] + ys = y.items if type(y) == Batch else [y] + for pair in product(xs, ys): + yield pair + else: + yield pair + return java_cartesian.flatMap(unpack_batches) + + def groupBy(self, f, numSplits=None): + """ + Return an RDD of grouped items. + + >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8]) + >>> result = rdd.groupBy(lambda x: x % 2).collect() + >>> sorted([(x, sorted(y)) for (x, y) in result]) + [(0, [2, 8]), (1, [1, 1, 3, 5])] + """ + return self.map(lambda x: (f(x), x)).groupByKey(numSplits) + + def pipe(self, command, env={}): + """ + Return an RDD created by piping elements to a forked external process. + + >>> sc.parallelize([1, 2, 3]).pipe('cat').collect() + ['1', '2', '3'] + """ + def func(iterator): + pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) + def pipe_objs(out): + for obj in iterator: + out.write(str(obj).rstrip('\n') + '\n') + out.close() + Thread(target=pipe_objs, args=[pipe.stdin]).start() + return (x.rstrip('\n') for x in pipe.stdout) + return self.mapPartitions(func) + + def foreach(self, f): + """ + Applies a function to all elements of this RDD. + + >>> def f(x): print x + >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) + """ + self.map(f).collect() # Force evaluation + + def collect(self): + """ + Return a list that contains all of the elements in this RDD. + """ + picklesInJava = self._jrdd.collect().iterator() + return list(self._collect_iterator_through_file(picklesInJava)) + + def _collect_iterator_through_file(self, iterator): + # Transferring lots of data through Py4J can be slow because + # socket.readline() is inefficient. Instead, we'll dump the data to a + # file and read it back. + tempFile = NamedTemporaryFile(delete=False) + tempFile.close() + def clean_up_file(): + try: os.unlink(tempFile.name) + except: pass + atexit.register(clean_up_file) + self.ctx._writeIteratorToPickleFile(iterator, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + for item in read_from_pickle_file(tempFile): + yield item + os.unlink(tempFile.name) + + def reduce(self, f): + """ + Reduces the elements of this RDD using the specified associative binary + operator. + + >>> from operator import add + >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add) + 15 + >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add) + 10 + """ + def func(iterator): + acc = None + for obj in iterator: + if acc is None: + acc = obj + else: + acc = f(obj, acc) + if acc is not None: + yield acc + vals = self.mapPartitions(func).collect() + return reduce(f, vals) + + def fold(self, zeroValue, op): + """ + Aggregate the elements of each partition, and then the results for all + the partitions, using a given associative function and a neutral "zero + value." + + The function C{op(t1, t2)} is allowed to modify C{t1} and return it + as its result value to avoid object allocation; however, it should not + modify C{t2}. + + >>> from operator import add + >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) + 15 + """ + def func(iterator): + acc = zeroValue + for obj in iterator: + acc = op(obj, acc) + yield acc + vals = self.mapPartitions(func).collect() + return reduce(op, vals, zeroValue) + + # TODO: aggregate + + def sum(self): + """ + Add up the elements in this RDD. + + >>> sc.parallelize([1.0, 2.0, 3.0]).sum() + 6.0 + """ + return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) + + def count(self): + """ + Return the number of elements in this RDD. + + >>> sc.parallelize([2, 3, 4]).count() + 3 + """ + return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum() + + def countByValue(self): + """ + Return the count of each unique value in this RDD as a dictionary of + (value, count) pairs. + + >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items()) + [(1, 2), (2, 3)] + """ + def countPartition(iterator): + counts = defaultdict(int) + for obj in iterator: + counts[obj] += 1 + yield counts + def mergeMaps(m1, m2): + for (k, v) in m2.iteritems(): + m1[k] += v + return m1 + return self.mapPartitions(countPartition).reduce(mergeMaps) + + def take(self, num): + """ + Take the first num elements of the RDD. + + This currently scans the partitions *one by one*, so it will be slow if + a lot of partitions are required. In that case, use L{collect} to get + the whole RDD instead. + + >>> sc.parallelize([2, 3, 4, 5, 6]).take(2) + [2, 3] + >>> sc.parallelize([2, 3, 4, 5, 6]).take(10) + [2, 3, 4, 5, 6] + """ + items = [] + splits = self._jrdd.splits() + taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0) + while len(items) < num and splits: + split = splits.pop(0) + iterator = self._jrdd.iterator(split, taskContext) + items.extend(self._collect_iterator_through_file(iterator)) + return items[:num] + + def first(self): + """ + Return the first element in this RDD. + + >>> sc.parallelize([2, 3, 4]).first() + 2 + """ + return self.take(1)[0] + + def saveAsTextFile(self, path): + """ + Save this RDD as a text file, using string representations of elements. + + >>> tempFile = NamedTemporaryFile(delete=True) + >>> tempFile.close() + >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name) + >>> from fileinput import input + >>> from glob import glob + >>> ''.join(input(glob(tempFile.name + "/part-0000*"))) + '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n' + """ + def func(iterator): + return (str(x).encode("utf-8") for x in iterator) + keyed = PipelinedRDD(self, func) + keyed._bypass_serializer = True + keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path) + + # Pair functions + + def collectAsMap(self): + """ + Return the key-value pairs in this RDD to the master as a dictionary. + + >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap() + >>> m[1] + 2 + >>> m[3] + 4 + """ + return dict(self.collect()) + + def reduceByKey(self, func, numSplits=None): + """ + Merge the values for each key using an associative reduce function. + + This will also perform the merging locally on each mapper before + sending results to a reducer, similarly to a "combiner" in MapReduce. + + Output will be hash-partitioned with C{numSplits} splits, or the + default parallelism level if C{numSplits} is not specified. + + >>> from operator import add + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.reduceByKey(add).collect()) + [('a', 2), ('b', 1)] + """ + return self.combineByKey(lambda x: x, func, func, numSplits) + + def reduceByKeyLocally(self, func): + """ + Merge the values for each key using an associative reduce function, but + return the results immediately to the master as a dictionary. + + This will also perform the merging locally on each mapper before + sending results to a reducer, similarly to a "combiner" in MapReduce. + + >>> from operator import add + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.reduceByKeyLocally(add).items()) + [('a', 2), ('b', 1)] + """ + def reducePartition(iterator): + m = {} + for (k, v) in iterator: + m[k] = v if k not in m else func(m[k], v) + yield m + def mergeMaps(m1, m2): + for (k, v) in m2.iteritems(): + m1[k] = v if k not in m1 else func(m1[k], v) + return m1 + return self.mapPartitions(reducePartition).reduce(mergeMaps) + + def countByKey(self): + """ + Count the number of elements for each key, and return the result to the + master as a dictionary. + + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.countByKey().items()) + [('a', 2), ('b', 1)] + """ + return self.map(lambda x: x[0]).countByValue() + + def join(self, other, numSplits=None): + """ + Return an RDD containing all pairs of elements with matching keys in + C{self} and C{other}. + + Each pair of elements will be returned as a (k, (v1, v2)) tuple, where + (k, v1) is in C{self} and (k, v2) is in C{other}. + + Performs a hash join across the cluster. + + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2), ("a", 3)]) + >>> sorted(x.join(y).collect()) + [('a', (1, 2)), ('a', (1, 3))] + """ + return python_join(self, other, numSplits) + + def leftOuterJoin(self, other, numSplits=None): + """ + Perform a left outer join of C{self} and C{other}. + + For each element (k, v) in C{self}, the resulting RDD will either + contain all pairs (k, (v, w)) for w in C{other}, or the pair + (k, (v, None)) if no elements in other have key k. + + Hash-partitions the resulting RDD into the given number of partitions. + + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) + >>> sorted(x.leftOuterJoin(y).collect()) + [('a', (1, 2)), ('b', (4, None))] + """ + return python_left_outer_join(self, other, numSplits) + + def rightOuterJoin(self, other, numSplits=None): + """ + Perform a right outer join of C{self} and C{other}. + + For each element (k, w) in C{other}, the resulting RDD will either + contain all pairs (k, (v, w)) for v in this, or the pair (k, (None, w)) + if no elements in C{self} have key k. + + Hash-partitions the resulting RDD into the given number of partitions. + + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) + >>> sorted(y.rightOuterJoin(x).collect()) + [('a', (2, 1)), ('b', (None, 4))] + """ + return python_right_outer_join(self, other, numSplits) + + # TODO: add option to control map-side combining + def partitionBy(self, numSplits, hashFunc=hash): + """ + Return a copy of the RDD partitioned using the specified partitioner. + + >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x)) + >>> sets = pairs.partitionBy(2).glom().collect() + >>> set(sets[0]).intersection(set(sets[1])) + set([]) + """ + if numSplits is None: + numSplits = self.ctx.defaultParallelism + # Transferring O(n) objects to Java is too expensive. Instead, we'll + # form the hash buckets in Python, transferring O(numSplits) objects + # to Java. Each object is a (splitNumber, [objects]) pair. + def add_shuffle_key(iterator): + buckets = defaultdict(list) + for (k, v) in iterator: + buckets[hashFunc(k) % numSplits].append((k, v)) + for (split, items) in buckets.iteritems(): + yield str(split) + yield dump_pickle(Batch(items)) + keyed = PipelinedRDD(self, add_shuffle_key) + keyed._bypass_serializer = True + pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() + partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) + jrdd = pairRDD.partitionBy(partitioner) + jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) + return RDD(jrdd, self.ctx) + + # TODO: add control over map-side aggregation + def combineByKey(self, createCombiner, mergeValue, mergeCombiners, + numSplits=None): + """ + Generic function to combine the elements for each key using a custom + set of aggregation functions. + + Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined + type" C. Note that V and C can be different -- for example, one might + group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]). + + Users provide three functions: + + - C{createCombiner}, which turns a V into a C (e.g., creates + a one-element list) + - C{mergeValue}, to merge a V into a C (e.g., adds it to the end of + a list) + - C{mergeCombiners}, to combine two C's into a single one. + + In addition, users can control the partitioning of the output RDD. + + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> def f(x): return x + >>> def add(a, b): return a + str(b) + >>> sorted(x.combineByKey(str, add, add).collect()) + [('a', '11'), ('b', '1')] + """ + if numSplits is None: + numSplits = self.ctx.defaultParallelism + def combineLocally(iterator): + combiners = {} + for (k, v) in iterator: + if k not in combiners: + combiners[k] = createCombiner(v) + else: + combiners[k] = mergeValue(combiners[k], v) + return combiners.iteritems() + locally_combined = self.mapPartitions(combineLocally) + shuffled = locally_combined.partitionBy(numSplits) + def _mergeCombiners(iterator): + combiners = {} + for (k, v) in iterator: + if not k in combiners: + combiners[k] = v + else: + combiners[k] = mergeCombiners(combiners[k], v) + return combiners.iteritems() + return shuffled.mapPartitions(_mergeCombiners) + + # TODO: support variant with custom partitioner + def groupByKey(self, numSplits=None): + """ + Group the values for each key in the RDD into a single sequence. + Hash-partitions the resulting RDD with into numSplits partitions. + + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(x.groupByKey().collect()) + [('a', [1, 1]), ('b', [1])] + """ + + def createCombiner(x): + return [x] + + def mergeValue(xs, x): + xs.append(x) + return xs + + def mergeCombiners(a, b): + return a + b + + return self.combineByKey(createCombiner, mergeValue, mergeCombiners, + numSplits) + + # TODO: add tests + def flatMapValues(self, f): + """ + Pass each value in the key-value pair RDD through a flatMap function + without changing the keys; this also retains the original RDD's + partitioning. + """ + flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + return self.flatMap(flat_map_fn, preservesPartitioning=True) + + def mapValues(self, f): + """ + Pass each value in the key-value pair RDD through a map function + without changing the keys; this also retains the original RDD's + partitioning. + """ + map_values_fn = lambda (k, v): (k, f(v)) + return self.map(map_values_fn, preservesPartitioning=True) + + # TODO: support varargs cogroup of several RDDs. + def groupWith(self, other): + """ + Alias for cogroup. + """ + return self.cogroup(other) + + # TODO: add variant with custom parittioner + def cogroup(self, other, numSplits=None): + """ + For each key k in C{self} or C{other}, return a resulting RDD that + contains a tuple with the list of values for that key in C{self} as well + as C{other}. + + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) + >>> sorted(x.cogroup(y).collect()) + [('a', ([1], [2])), ('b', ([4], []))] + """ + return python_cogroup(self, other, numSplits) + + # TODO: `lookup` is disabled because we can't make direct comparisons based + # on the key; we need to compare the hash of the key to the hash of the + # keys in the pairs. This could be an expensive operation, since those + # hashes aren't retained. + + +class PipelinedRDD(RDD): + """ + Pipelined maps: + >>> rdd = sc.parallelize([1, 2, 3, 4]) + >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + + Pipelined reduces: + >>> from operator import add + >>> rdd.map(lambda x: 2 * x).reduce(add) + 20 + >>> rdd.flatMap(lambda x: [x, x]).reduce(add) + 20 + """ + def __init__(self, prev, func, preservesPartitioning=False): + if isinstance(prev, PipelinedRDD) and not prev.is_cached: + prev_func = prev.func + def pipeline_func(iterator): + return func(prev_func(iterator)) + self.func = pipeline_func + self.preservesPartitioning = \ + prev.preservesPartitioning and preservesPartitioning + self._prev_jrdd = prev._prev_jrdd + else: + self.func = func + self.preservesPartitioning = preservesPartitioning + self._prev_jrdd = prev._jrdd + self.is_cached = False + self.ctx = prev.ctx + self.prev = prev + self._jrdd_val = None + self._bypass_serializer = False + + @property + def _jrdd(self): + if self._jrdd_val: + return self._jrdd_val + func = self.func + if not self._bypass_serializer and self.ctx.batchSize != 1: + oldfunc = self.func + batchSize = self.ctx.batchSize + def batched_func(iterator): + return batched(oldfunc(iterator), batchSize) + func = batched_func + cmds = [func, self._bypass_serializer] + pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], + self.ctx.gateway._gateway_client) + self.ctx._pickled_broadcast_vars.clear() + class_manifest = self._prev_jrdd.classManifest() + env = copy.copy(self.ctx.environment) + env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "") + env = MapConverter().convert(env, self.ctx.gateway._gateway_client) + python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), + pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, + broadcast_vars, class_manifest) + self._jrdd_val = python_rdd.asJavaRDD() + return self._jrdd_val + + +def _test(): + import doctest + from pyspark.context import SparkContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + doctest.testmod(globs=globs) + globs['sc'].stop() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py new file mode 100644 index 0000000000..9a5151ea00 --- /dev/null +++ b/python/pyspark/serializers.py @@ -0,0 +1,78 @@ +import struct +import cPickle + + +class Batch(object): + """ + Used to store multiple RDD entries as a single Java object. + + This relieves us from having to explicitly track whether an RDD + is stored as batches of objects and avoids problems when processing + the union() of batched and unbatched RDDs (e.g. the union() of textFile() + with another RDD). + """ + def __init__(self, items): + self.items = items + + +def batched(iterator, batchSize): + if batchSize == -1: # unlimited batch size + yield Batch(list(iterator)) + else: + items = [] + count = 0 + for item in iterator: + items.append(item) + count += 1 + if count == batchSize: + yield Batch(items) + items = [] + count = 0 + if items: + yield Batch(items) + + +def dump_pickle(obj): + return cPickle.dumps(obj, 2) + + +load_pickle = cPickle.loads + + +def read_long(stream): + length = stream.read(8) + if length == "": + raise EOFError + return struct.unpack("!q", length)[0] + + +def read_int(stream): + length = stream.read(4) + if length == "": + raise EOFError + return struct.unpack("!i", length)[0] + +def write_with_length(obj, stream): + stream.write(struct.pack("!i", len(obj))) + stream.write(obj) + + +def read_with_length(stream): + length = read_int(stream) + obj = stream.read(length) + if obj == "": + raise EOFError + return obj + + +def read_from_pickle_file(stream): + try: + while True: + obj = load_pickle(read_with_length(stream)) + if type(obj) == Batch: # We don't care about inheritance + for item in obj.items: + yield item + else: + yield obj + except EOFError: + return diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py new file mode 100644 index 0000000000..bd39b0283f --- /dev/null +++ b/python/pyspark/shell.py @@ -0,0 +1,33 @@ +""" +An interactive shell. +""" +import optparse # I prefer argparse, but it's not included with Python < 2.7 +import code +import sys + +from pyspark.context import SparkContext + + +def main(master='local', ipython=False): + sc = SparkContext(master, 'PySparkShell') + user_ns = {'sc' : sc} + banner = "Spark context avaiable as sc." + if ipython: + import IPython + IPython.embed(user_ns=user_ns, banner2=banner) + else: + print banner + code.interact(local=user_ns) + + +if __name__ == '__main__': + usage = "usage: %prog [options] master" + parser = optparse.OptionParser(usage=usage) + parser.add_option("-i", "--ipython", help="Run IPython shell", + action="store_true") + (options, args) = parser.parse_args() + if len(sys.argv) > 1: + master = args[0] + else: + master = 'local' + main(master, options.ipython) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py new file mode 100644 index 0000000000..9f6b507dbd --- /dev/null +++ b/python/pyspark/worker.py @@ -0,0 +1,40 @@ +""" +Worker that receives input from Piped RDD. +""" +import sys +from base64 import standard_b64decode +# CloudPickler needs to be imported so that depicklers are registered using the +# copy_reg module. +from pyspark.broadcast import Broadcast, _broadcastRegistry +from pyspark.cloudpickle import CloudPickler +from pyspark.serializers import write_with_length, read_with_length, \ + read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file + + +# Redirect stdout to stderr so that users must return values from functions. +old_stdout = sys.stdout +sys.stdout = sys.stderr + + +def load_obj(): + return load_pickle(standard_b64decode(sys.stdin.readline().strip())) + + +def main(): + num_broadcast_variables = read_int(sys.stdin) + for _ in range(num_broadcast_variables): + bid = read_long(sys.stdin) + value = read_with_length(sys.stdin) + _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) + func = load_obj() + bypassSerializer = load_obj() + if bypassSerializer: + dumps = lambda x: x + else: + dumps = dump_pickle + for obj in func(read_from_pickle_file(sys.stdin)): + write_with_length(dumps(obj), old_stdout) + + +if __name__ == '__main__': + main() diff --git a/run b/run index ed788c4db3..08e2b2434b 100755 --- a/run +++ b/run @@ -63,7 +63,7 @@ CORE_DIR="$FWDIR/core" REPL_DIR="$FWDIR/repl" EXAMPLES_DIR="$FWDIR/examples" BAGEL_DIR="$FWDIR/bagel" -PYSPARK_DIR="$FWDIR/pyspark" +PYSPARK_DIR="$FWDIR/python" # Build up classpath CLASSPATH="$SPARK_CLASSPATH" diff --git a/run-pyspark b/run-pyspark new file mode 100755 index 0000000000..deb0d708b3 --- /dev/null +++ b/run-pyspark @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +# Figure out where the Scala framework is installed +FWDIR="$(cd `dirname $0`; pwd)" + +# Export this as SPARK_HOME +export SPARK_HOME="$FWDIR" + +# Load environment variables from conf/spark-env.sh, if it exists +if [ -e $FWDIR/conf/spark-env.sh ] ; then + . $FWDIR/conf/spark-env.sh +fi + +# Figure out which Python executable to use +if [ -z "$PYSPARK_PYTHON" ] ; then + PYSPARK_PYTHON="python" +fi +export PYSPARK_PYTHON + +# Add the PySpark classes to the Python path: +export PYTHONPATH=$SPARK_HOME/python/:$PYTHONPATH + +# Launch with `scala` by default: +if [[ "$SPARK_LAUNCH_WITH_SCALA" != "0" ]] ; then + export SPARK_LAUNCH_WITH_SCALA=1 +fi + +exec "$PYSPARK_PYTHON" "$@" diff --git a/run2.cmd b/run2.cmd index 9c50804e69..83464b1166 100644 --- a/run2.cmd +++ b/run2.cmd @@ -34,7 +34,7 @@ set CORE_DIR=%FWDIR%core set REPL_DIR=%FWDIR%repl set EXAMPLES_DIR=%FWDIR%examples set BAGEL_DIR=%FWDIR%bagel -set PYSPARK_DIR=%FWDIR%pyspark +set PYSPARK_DIR=%FWDIR%python rem Build up classpath set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes -- cgit v1.2.3 From 3dc87dd923578f20f2c6945be7d8c03797e76237 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 1 Jan 2013 16:38:04 -0800 Subject: Fixed compilation bug in RDDSuite created during merge for mesos/master. --- core/src/test/scala/spark/RDDSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index eab09956bb..e5a59dc7d6 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -105,9 +105,9 @@ class RDDSuite extends FunSuite with BeforeAndAfter { sc = new SparkContext("local", "test") val onlySplit = new Split { override def index: Int = 0 } var shouldFail = true - val rdd = new RDD[Int](sc) { - override def splits: Array[Split] = Array(onlySplit) - override val dependencies = List[Dependency[_]]() + val rdd = new RDD[Int](sc, Nil) { + override def getSplits: Array[Split] = Array(onlySplit) + override val getDependencies = List[Dependency[_]]() override def compute(split: Split, context: TaskContext): Iterator[Int] = { if (shouldFail) { throw new Exception("injected failure") -- cgit v1.2.3 From 33beba39656fc64984db09a82fc69ca4edcc02d4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 3 Jan 2013 14:52:21 -0800 Subject: Change PySpark RDD.take() to not call iterator(). --- core/src/main/scala/spark/api/python/PythonRDD.scala | 4 ++++ python/pyspark/context.py | 1 + python/pyspark/rdd.py | 11 +++++------ 3 files changed, 10 insertions(+), 6 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index cf60d14f03..79d824d494 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -10,6 +10,7 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import spark.broadcast.Broadcast import spark._ import spark.rdd.PipedRDD +import java.util private[spark] class PythonRDD[T: ClassManifest]( @@ -216,6 +217,9 @@ private[spark] object PythonRDD { } file.close() } + + def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] = + rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head } private object Pickle { diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6172d69dcf..4439356c1f 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -21,6 +21,7 @@ class SparkContext(object): jvm = gateway.jvm _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile + _takePartition = jvm.PythonRDD.takePartition def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index cbffb6cc1f..4ba417b2a2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -328,18 +328,17 @@ class RDD(object): a lot of partitions are required. In that case, use L{collect} to get the whole RDD instead. - >>> sc.parallelize([2, 3, 4, 5, 6]).take(2) + >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2) [2, 3] >>> sc.parallelize([2, 3, 4, 5, 6]).take(10) [2, 3, 4, 5, 6] """ items = [] - splits = self._jrdd.splits() - taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0) - while len(items) < num and splits: - split = splits.pop(0) - iterator = self._jrdd.iterator(split, taskContext) + for partition in range(self._jrdd.splits().size()): + iterator = self.ctx._takePartition(self._jrdd.rdd(), partition) items.extend(self._collect_iterator_through_file(iterator)) + if len(items) >= num: + break return items[:num] def first(self): -- cgit v1.2.3 From 8d57c78c83f74e45ce3c119e2e3915d5eac264e7 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 5 Jan 2013 10:54:05 -0600 Subject: Add PairRDDFunctions.keys and values. --- core/src/main/scala/spark/PairRDDFunctions.scala | 10 ++++++++++ core/src/test/scala/spark/ShuffleSuite.scala | 7 +++++++ 2 files changed, 17 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 413c944a66..ce48cea903 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -615,6 +615,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( writer.cleanup() } + /** + * Return an RDD with the keys of each tuple. + */ + def keys: RDD[K] = self.map(_._1) + + /** + * Return an RDD with the values of each tuple. + */ + def values: RDD[V] = self.map(_._2) + private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 8170100f1d..5a867016f2 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -216,6 +216,13 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { // Test that a shuffle on the file works, because this used to be a bug assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) } + + test("kesy and values") { + sc = new SparkContext("local", "test") + val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) + assert(rdd.keys.collect().toList === List(1, 2)) + assert(rdd.values.collect().toList === List("a", "b")) + } } object ShuffleSuite { -- cgit v1.2.3 From f4e6b9361ffeec1018d5834f09db9fd86f2ba7bd Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 4 Jan 2013 22:43:22 -0600 Subject: Add RDD.collect(PartialFunction). --- core/src/main/scala/spark/RDD.scala | 7 +++++++ core/src/test/scala/spark/RDDSuite.scala | 1 + 2 files changed, 8 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 7e38583391..5163c80134 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -329,6 +329,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial */ def toArray(): Array[T] = collect() + /** + * Return an RDD that contains all matching values by applying `f`. + */ + def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = { + filter(f.isDefinedAt).map(f) + } + /** * Reduces the elements of this RDD using the specified associative binary operator. */ diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 45e6c5f840..872b06fd08 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -35,6 +35,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) + assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4")) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) assert(partitionSums.collect().toList === List(3, 7)) -- cgit v1.2.3 From 6a0db3b449a829f3e5cdf7229f6ee564268be1df Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 5 Jan 2013 12:56:17 -0600 Subject: Fix typo. --- core/src/test/scala/spark/ShuffleSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 5a867016f2..bebb8ebe86 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -217,7 +217,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) } - test("kesy and values") { + test("keys and values") { sc = new SparkContext("local", "test") val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) assert(rdd.keys.collect().toList === List(1, 2)) -- cgit v1.2.3 From 1fdb6946b5d076ed0f1b4d2bca2a20b6cd22cbc3 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 5 Jan 2013 13:07:59 -0600 Subject: Add RDD.tupleBy. --- core/src/main/scala/spark/RDD.scala | 7 +++++++ core/src/test/scala/spark/RDDSuite.scala | 1 + 2 files changed, 8 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 7e38583391..7aa4b0a173 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -510,6 +510,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial .saveAsSequenceFile(path) } + /** + * Tuples the elements of this RDD by applying `f`. + */ + def tupleBy[K](f: T => K): RDD[(K, T)] = { + map(x => (f(x), x)) + } + /** A private method for tests, to look at the contents of each partition */ private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 45e6c5f840..7832884224 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -35,6 +35,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) + assert(nums.tupleBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) assert(partitionSums.collect().toList === List(3, 7)) -- cgit v1.2.3 From ecf9c0890160c69f1b64b36fa8fdea2f6dd973eb Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 5 Jan 2013 20:54:08 -0500 Subject: Fix Accumulators in Java, and add a test for them --- core/src/main/scala/spark/Accumulators.scala | 18 ++++++++- core/src/main/scala/spark/SparkContext.scala | 7 ++-- .../scala/spark/api/java/JavaSparkContext.scala | 23 +++++++---- core/src/test/scala/spark/JavaAPISuite.java | 44 ++++++++++++++++++++++ 4 files changed, 79 insertions(+), 13 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index bacd0ace37..6280f25391 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -38,14 +38,28 @@ class Accumulable[R, T] ( */ def += (term: T) { value_ = param.addAccumulator(value_, term) } + /** + * Add more data to this accumulator / accumulable + * @param term the data to add + */ + def add(term: T) { value_ = param.addAccumulator(value_, term) } + /** * Merge two accumulable objects together - * + * * Normally, a user will not want to use this version, but will instead call `+=`. - * @param term the other Accumulable that will get merged with this + * @param term the other `R` that will get merged with this */ def ++= (term: R) { value_ = param.addInPlace(value_, term)} + /** + * Merge two accumulable objects together + * + * Normally, a user will not want to use this version, but will instead call `add`. + * @param term the other `R` that will get merged with this + */ + def merge(term: R) { value_ = param.addInPlace(value_, term)} + /** * Access the accumulator's current value; only allowed on master. */ diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 4fd81bc63b..bbf8272eb3 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -382,11 +382,12 @@ class SparkContext( new Accumulator(initialValue, param) /** - * Create an [[spark.Accumulable]] shared variable, with a `+=` method + * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`. + * Only the master can access the accumuable's `value`. * @tparam T accumulator type * @tparam R type that can be added to the accumulator */ - def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) = + def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = new Accumulable(initialValue, param) /** @@ -404,7 +405,7 @@ class SparkContext( * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. */ - def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal) + def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) /** * Add a file to be downloaded into the working directory of this Spark job on every node. diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index b7725313c4..bf9ad7a200 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -10,7 +10,7 @@ import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import spark.{Accumulator, AccumulatorParam, RDD, SparkContext} +import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext} import spark.SparkContext.IntAccumulatorParam import spark.SparkContext.DoubleAccumulatorParam import spark.broadcast.Broadcast @@ -265,25 +265,32 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork /** * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ - def intAccumulator(initialValue: Int): Accumulator[Int] = - sc.accumulator(initialValue)(IntAccumulatorParam) + def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] = + sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]] /** * Create an [[spark.Accumulator]] double variable, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ - def doubleAccumulator(initialValue: Double): Accumulator[Double] = - sc.accumulator(initialValue)(DoubleAccumulatorParam) + def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = + sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] /** * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = sc.accumulator(initialValue)(accumulatorParam) + /** + * Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can + * "add" values with `add`. Only the master can access the accumuable's `value`. + */ + def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = + sc.accumulable(initialValue)(param) + /** * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 33d5fc2d89..b99e790093 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -581,4 +581,48 @@ public class JavaAPISuite implements Serializable { JavaPairRDD zipped = rdd.zip(doubles); zipped.count(); } + + @Test + public void accumulators() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + + final Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + intAccum.add(x); + } + }); + Assert.assertEquals((Integer) 25, intAccum.value()); + + final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + doubleAccum.add((double) x); + } + }); + Assert.assertEquals((Double) 25.0, doubleAccum.value()); + + // Try a custom accumulator type + AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + public Float addInPlace(Float r, Float t) { + return r + t; + } + + public Float addAccumulator(Float r, Float t) { + return r + t; + } + + public Float zero(Float initialValue) { + return 0.0f; + } + }; + + final Accumulator floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + floatAccum.add((float) x); + } + }); + Assert.assertEquals((Float) 25.0f, floatAccum.value()); + } } -- cgit v1.2.3 From 86af64b0a6fde5a6418727a77b43bdfeda1b81cd Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 5 Jan 2013 20:54:08 -0500 Subject: Fix Accumulators in Java, and add a test for them --- core/src/main/scala/spark/Accumulators.scala | 18 ++++++++- core/src/main/scala/spark/SparkContext.scala | 7 ++-- .../scala/spark/api/java/JavaSparkContext.scala | 23 +++++++---- core/src/test/scala/spark/JavaAPISuite.java | 44 ++++++++++++++++++++++ 4 files changed, 79 insertions(+), 13 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index bacd0ace37..6280f25391 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -38,14 +38,28 @@ class Accumulable[R, T] ( */ def += (term: T) { value_ = param.addAccumulator(value_, term) } + /** + * Add more data to this accumulator / accumulable + * @param term the data to add + */ + def add(term: T) { value_ = param.addAccumulator(value_, term) } + /** * Merge two accumulable objects together - * + * * Normally, a user will not want to use this version, but will instead call `+=`. - * @param term the other Accumulable that will get merged with this + * @param term the other `R` that will get merged with this */ def ++= (term: R) { value_ = param.addInPlace(value_, term)} + /** + * Merge two accumulable objects together + * + * Normally, a user will not want to use this version, but will instead call `add`. + * @param term the other `R` that will get merged with this + */ + def merge(term: R) { value_ = param.addInPlace(value_, term)} + /** * Access the accumulator's current value; only allowed on master. */ diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 4fd81bc63b..bbf8272eb3 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -382,11 +382,12 @@ class SparkContext( new Accumulator(initialValue, param) /** - * Create an [[spark.Accumulable]] shared variable, with a `+=` method + * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`. + * Only the master can access the accumuable's `value`. * @tparam T accumulator type * @tparam R type that can be added to the accumulator */ - def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) = + def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = new Accumulable(initialValue, param) /** @@ -404,7 +405,7 @@ class SparkContext( * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. */ - def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal) + def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) /** * Add a file to be downloaded into the working directory of this Spark job on every node. diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index b7725313c4..bf9ad7a200 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -10,7 +10,7 @@ import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import spark.{Accumulator, AccumulatorParam, RDD, SparkContext} +import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext} import spark.SparkContext.IntAccumulatorParam import spark.SparkContext.DoubleAccumulatorParam import spark.broadcast.Broadcast @@ -265,25 +265,32 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork /** * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ - def intAccumulator(initialValue: Int): Accumulator[Int] = - sc.accumulator(initialValue)(IntAccumulatorParam) + def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] = + sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]] /** * Create an [[spark.Accumulator]] double variable, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ - def doubleAccumulator(initialValue: Double): Accumulator[Double] = - sc.accumulator(initialValue)(DoubleAccumulatorParam) + def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = + sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] /** * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = sc.accumulator(initialValue)(accumulatorParam) + /** + * Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can + * "add" values with `add`. Only the master can access the accumuable's `value`. + */ + def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = + sc.accumulable(initialValue)(param) + /** * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 33d5fc2d89..b99e790093 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -581,4 +581,48 @@ public class JavaAPISuite implements Serializable { JavaPairRDD zipped = rdd.zip(doubles); zipped.count(); } + + @Test + public void accumulators() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + + final Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + intAccum.add(x); + } + }); + Assert.assertEquals((Integer) 25, intAccum.value()); + + final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + doubleAccum.add((double) x); + } + }); + Assert.assertEquals((Double) 25.0, doubleAccum.value()); + + // Try a custom accumulator type + AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + public Float addInPlace(Float r, Float t) { + return r + t; + } + + public Float addAccumulator(Float r, Float t) { + return r + t; + } + + public Float zero(Float initialValue) { + return 0.0f; + } + }; + + final Accumulator floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + floatAccum.add((float) x); + } + }); + Assert.assertEquals((Float) 25.0f, floatAccum.value()); + } } -- cgit v1.2.3 From 0982572519655354b10987de4f68e29b8331bd2a Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 5 Jan 2013 22:11:28 -0500 Subject: Add methods called just 'accumulator' for int/double in Java API --- core/src/main/scala/spark/api/java/JavaSparkContext.scala | 13 +++++++++++++ core/src/test/scala/spark/JavaAPISuite.java | 4 ++-- 2 files changed, 15 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index bf9ad7a200..88ab2846be 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -277,6 +277,19 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] + /** + * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + */ + def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue) + + /** + * Create an [[spark.Accumulator]] double variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + */ + def accumulator(initialValue: Double): Accumulator[java.lang.Double] = + doubleAccumulator(initialValue) + /** * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index b99e790093..912f8de05d 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -586,7 +586,7 @@ public class JavaAPISuite implements Serializable { public void accumulators() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - final Accumulator intAccum = sc.intAccumulator(10); + final Accumulator intAccum = sc.accumulator(10); rdd.foreach(new VoidFunction() { public void call(Integer x) { intAccum.add(x); @@ -594,7 +594,7 @@ public class JavaAPISuite implements Serializable { }); Assert.assertEquals((Integer) 25, intAccum.value()); - final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + final Accumulator doubleAccum = sc.accumulator(10.0); rdd.foreach(new VoidFunction() { public void call(Integer x) { doubleAccum.add((double) x); -- cgit v1.2.3 From 8fd3a70c188182105f81f5143ec65e74663582d5 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 5 Jan 2013 22:46:45 -0500 Subject: Add PairRDD.keys() and values() to Java API --- core/src/main/scala/spark/api/java/JavaPairRDD.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala index 5c2be534ff..8ce32e0e2f 100644 --- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala @@ -471,6 +471,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x) fromRDD(new OrderedRDDFunctions(rdd).sortByKey(ascending)) } + + /** + * Return an RDD with the keys of each tuple. + */ + def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1)) + + /** + * Return an RDD with the values of each tuple. + */ + def values(): JavaRDD[V] = JavaRDD.fromRDD[V](rdd.map(_._2)) } object JavaPairRDD { -- cgit v1.2.3 From 8dc06069fe2330c3ee0fcaaeb0ae6e627a5887c3 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sun, 6 Jan 2013 15:21:45 -0600 Subject: Rename RDD.tupleBy to keyBy. --- core/src/main/scala/spark/RDD.scala | 4 ++-- core/src/test/scala/spark/RDDSuite.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 7aa4b0a173..5ce524c0e7 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -511,9 +511,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial } /** - * Tuples the elements of this RDD by applying `f`. + * Creates tuples of the elements in this RDD by applying `f`. */ - def tupleBy[K](f: T => K): RDD[(K, T)] = { + def keyBy[K](f: T => K): RDD[(K, T)] = { map(x => (f(x), x)) } diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 7832884224..77bff8aba1 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -35,7 +35,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) - assert(nums.tupleBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) + assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) assert(partitionSums.collect().toList === List(3, 7)) -- cgit v1.2.3 From 1346126485444afc065bf4951c4bedebe5c95ce4 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Jan 2013 12:11:27 -0800 Subject: Changed cleanup to clearOldValues for TimeStampedHashMap and TimeStampedHashSet. --- core/src/main/scala/spark/CacheTracker.scala | 4 ++-- core/src/main/scala/spark/MapOutputTracker.scala | 4 ++-- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 6 +++--- core/src/main/scala/spark/scheduler/ResultTask.scala | 2 +- core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 2 +- core/src/main/scala/spark/util/TimeStampedHashMap.scala | 7 +++++-- core/src/main/scala/spark/util/TimeStampedHashSet.scala | 5 ++++- 7 files changed, 18 insertions(+), 12 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 7d320c4fe5..86ad737583 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -39,7 +39,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { private val slaveCapacity = new HashMap[String, Long] private val slaveUsage = new HashMap[String, Long] - private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.cleanup) + private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.clearOldValues) private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) @@ -120,7 +120,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b // Remembers which splits are currently being loaded (on worker nodes) val loading = new HashSet[String] - val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.cleanup) + val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.clearOldValues) // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 5ebdba0fc8..a2fa2d1ea7 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -178,8 +178,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea } def cleanup(cleanupTime: Long) { - mapStatuses.cleanup(cleanupTime) - cachedSerializedStatuses.cleanup(cleanupTime) + mapStatuses.clearOldValues(cleanupTime) + cachedSerializedStatuses.clearOldValues(cleanupTime) } def stop() { diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 9387ba19a3..59f2099e91 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -599,15 +599,15 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def cleanup(cleanupTime: Long) { var sizeBefore = idToStage.size - idToStage.cleanup(cleanupTime) + idToStage.clearOldValues(cleanupTime) logInfo("idToStage " + sizeBefore + " --> " + idToStage.size) sizeBefore = shuffleToMapStage.size - shuffleToMapStage.cleanup(cleanupTime) + shuffleToMapStage.clearOldValues(cleanupTime) logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size) sizeBefore = pendingTasks.size - pendingTasks.cleanup(cleanupTime) + pendingTasks.clearOldValues(cleanupTime) logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size) } diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index 7ec6564105..74a63c1af1 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -12,7 +12,7 @@ private[spark] object ResultTask { // expensive on the master node if it needs to launch thousands of tasks. val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.cleanup) + val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues) def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { synchronized { diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index feb63abb61..19f5328eee 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -23,7 +23,7 @@ private[spark] object ShuffleMapTask { // expensive on the master node if it needs to launch thousands of tasks. val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.cleanup) + val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues) def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index 7e785182ea..bb7c5c01c8 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -7,7 +7,7 @@ import scala.collection.mutable.Map /** * This is a custom implementation of scala.collection.mutable.Map which stores the insertion * time stamp along with each key-value pair. Key-value pairs that are older than a particular - * threshold time can them be removed using the cleanup method. This is intended to be a drop-in + * threshold time can them be removed using the clearOldValues method. This is intended to be a drop-in * replacement of scala.collection.mutable.HashMap. */ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { @@ -74,7 +74,10 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { } } - def cleanup(threshTime: Long) { + /** + * Removes old key-value pairs that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { val iterator = internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() diff --git a/core/src/main/scala/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/spark/util/TimeStampedHashSet.scala index 539dd75844..5f1cc93752 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashSet.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashSet.scala @@ -52,7 +52,10 @@ class TimeStampedHashSet[A] extends Set[A] { } } - def cleanup(threshTime: Long) { + /** + * Removes old values that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { val iterator = internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() -- cgit v1.2.3 From 9c32f300fb4151a2b563bf3d2e46469722e016e1 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 7 Jan 2013 16:50:23 -0500 Subject: Add Accumulable.setValue for easier use in Java --- core/src/main/scala/spark/Accumulators.scala | 20 +++++++++++++++----- core/src/test/scala/spark/JavaAPISuite.java | 4 ++++ 2 files changed, 19 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index 6280f25391..b644aba5f8 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -63,9 +63,12 @@ class Accumulable[R, T] ( /** * Access the accumulator's current value; only allowed on master. */ - def value = { - if (!deserialized) value_ - else throw new UnsupportedOperationException("Can't read accumulator value in task") + def value: R = { + if (!deserialized) { + value_ + } else { + throw new UnsupportedOperationException("Can't read accumulator value in task") + } } /** @@ -82,10 +85,17 @@ class Accumulable[R, T] ( /** * Set the accumulator's value; only allowed on master. */ - def value_= (r: R) { - if (!deserialized) value_ = r + def value_= (newValue: R) { + if (!deserialized) value_ = newValue else throw new UnsupportedOperationException("Can't assign accumulator value in task") } + + /** + * Set the accumulator's value; only allowed on master + */ + def setValue(newValue: R) { + this.value = newValue + } // Called by Java when deserializing an object private def readObject(in: ObjectInputStream) { diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 912f8de05d..0817d1146c 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -624,5 +624,9 @@ public class JavaAPISuite implements Serializable { } }); Assert.assertEquals((Float) 25.0f, floatAccum.value()); + + // Test the setValue method + floatAccum.setValue(5.0f); + Assert.assertEquals((Float) 5.0f, floatAccum.value()); } } -- cgit v1.2.3 From 237bac36e9dca8828192994dad323b8da1619267 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Jan 2013 14:37:21 -0800 Subject: Renamed examples and added documentation. --- core/src/main/scala/spark/RDDCheckpointData.scala | 4 +- docs/streaming-programming-guide.md | 14 ++-- .../spark/streaming/examples/FileStream.scala | 46 ------------- .../examples/FileStreamWithCheckpoint.scala | 75 ---------------------- .../spark/streaming/examples/FlumeEventCount.scala | 2 +- .../scala/spark/streaming/examples/GrepRaw.scala | 32 --------- .../spark/streaming/examples/HdfsWordCount.scala | 36 +++++++++++ .../streaming/examples/NetworkWordCount.scala | 36 +++++++++++ .../spark/streaming/examples/RawNetworkGrep.scala | 46 +++++++++++++ .../streaming/examples/TopKWordCountRaw.scala | 49 -------------- .../spark/streaming/examples/WordCountHdfs.scala | 25 -------- .../streaming/examples/WordCountNetwork.scala | 25 -------- .../spark/streaming/examples/WordCountRaw.scala | 43 ------------- .../scala/spark/streaming/StreamingContext.scala | 38 ++++++++--- .../spark/streaming/dstream/FileInputDStream.scala | 16 ++--- .../scala/spark/streaming/InputStreamsSuite.scala | 2 +- 16 files changed, 163 insertions(+), 326 deletions(-) delete mode 100644 examples/src/main/scala/spark/streaming/examples/FileStream.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/GrepRaw.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala (limited to 'core') diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index 7af830940f..e270b6312e 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -65,7 +65,7 @@ extends Logging with Serializable { cpRDD = Some(newRDD) rdd.changeDependencies(newRDD) cpState = Checkpointed - RDDCheckpointData.checkpointCompleted() + RDDCheckpointData.clearTaskCaches() logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) } } @@ -90,7 +90,7 @@ extends Logging with Serializable { } private[spark] object RDDCheckpointData { - def checkpointCompleted() { + def clearTaskCaches() { ShuffleMapTask.clearCache() ResultTask.clearCache() } diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index fc2ea2ef79..05a88ce7bd 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -187,8 +187,8 @@ Conversely, the computation can be stopped by using ssc.stop() {% endhighlight %} -# Example - WordCountNetwork.scala -A good example to start off is the spark.streaming.examples.WordCountNetwork. This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in /streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala. +# Example - NetworkWordCount.scala +A good example to start off is the spark.streaming.examples.NetworkWordCount. This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in /streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala. {% highlight scala %} import spark.streaming.{Seconds, StreamingContext} @@ -196,7 +196,7 @@ import spark.streaming.StreamingContext._ ... // Create the context and set up a network input stream to receive from a host:port -val ssc = new StreamingContext(args(0), "WordCountNetwork", Seconds(1)) +val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1)) val lines = ssc.networkTextStream(args(1), args(2).toInt) // Split the lines into words, count them, and print some of the counts on the master @@ -214,13 +214,13 @@ To run this example on your local machine, you need to first run a Netcat server $ nc -lk 9999 {% endhighlight %} -Then, in a different terminal, you can start WordCountNetwork by using +Then, in a different terminal, you can start NetworkWordCount by using {% highlight bash %} -$ ./run spark.streaming.examples.WordCountNetwork local[2] localhost 9999 +$ ./run spark.streaming.examples.NetworkWordCount local[2] localhost 9999 {% endhighlight %} -This will make WordCountNetwork connect to the netcat server. Any lines typed in the terminal running the netcat server will be counted and printed on screen. +This will make NetworkWordCount connect to the netcat server. Any lines typed in the terminal running the netcat server will be counted and printed on screen.
print() Prints the contents of this DStream on the driver. At each interval, this will take at most ten elements from the DStream's RDD and print them. Prints first ten elements of every batch of data in a DStream on the driver.
saveAsObjectFile(prefix, [suffix]) saveAsObjectFiles(prefix, [suffix]) Save this DStream's contents as a SequenceFile of serialized objects. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
saveAsTextFile(prefix, suffix) saveAsTextFiles(prefix, [suffix]) Save this DStream's contents as a text files. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
saveAsHadoopFiles(prefix, suffix) saveAsHadoopFiles(prefix, [suffix]) Save this DStream's contents as a Hadoop file. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
- + - + @@ -88,55 +88,60 @@ DStreams support many of the transformations available on normal Spark RDD's: - + - + + + + +
@@ -240,7 +240,7 @@ hello world {% highlight bash %} -# TERMINAL 2: RUNNING WordCountNetwork +# TERMINAL 2: RUNNING NetworkWordCount ... 2012-12-31 18:47:10,446 INFO SparkContext: Job finished: run at ThreadPoolExecutor.java:886, took 0.038817 s ------------------------------------------- diff --git a/examples/src/main/scala/spark/streaming/examples/FileStream.scala b/examples/src/main/scala/spark/streaming/examples/FileStream.scala deleted file mode 100644 index 81938d30d4..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/FileStream.scala +++ /dev/null @@ -1,46 +0,0 @@ -package spark.streaming.examples - -import spark.streaming.StreamingContext -import spark.streaming.StreamingContext._ -import spark.streaming.Seconds -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration - - -object FileStream { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: FileStream ") - System.exit(1) - } - - // Create the context - val ssc = new StreamingContext(args(0), "FileStream", Seconds(1)) - - // Create the new directory - val directory = new Path(args(1)) - val fs = directory.getFileSystem(new Configuration()) - if (fs.exists(directory)) throw new Exception("This directory already exists") - fs.mkdirs(directory) - fs.deleteOnExit(directory) - - // Create the FileInputDStream on the directory and use the - // stream to count words in new files created - val inputStream = ssc.textFileStream(directory.toString) - val words = inputStream.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - - // Creating new files in the directory - val text = "This is a text file" - for (i <- 1 to 30) { - ssc.sc.parallelize((1 to (i * 10)).map(_ => text), 10) - .saveAsTextFile(new Path(directory, i.toString).toString) - Thread.sleep(1000) - } - Thread.sleep(5000) // Waiting for the file to be processed - ssc.stop() - System.exit(0) - } -} \ No newline at end of file diff --git a/examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala deleted file mode 100644 index b7bc15a1d5..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala +++ /dev/null @@ -1,75 +0,0 @@ -package spark.streaming.examples - -import spark.streaming._ -import spark.streaming.StreamingContext._ -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration - -object FileStreamWithCheckpoint { - - def main(args: Array[String]) { - - if (args.size != 3) { - println("FileStreamWithCheckpoint ") - println("FileStreamWithCheckpoint restart ") - System.exit(-1) - } - - val directory = new Path(args(1)) - val checkpointDir = args(2) - - val ssc: StreamingContext = { - - if (args(0) == "restart") { - - // Recreated streaming context from specified checkpoint file - new StreamingContext(checkpointDir) - - } else { - - // Create directory if it does not exist - val fs = directory.getFileSystem(new Configuration()) - if (!fs.exists(directory)) fs.mkdirs(directory) - - // Create new streaming context - val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint", Seconds(1)) - ssc_.checkpoint(checkpointDir) - - // Setup the streaming computation - val inputStream = ssc_.textFileStream(directory.toString) - val words = inputStream.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - - ssc_ - } - } - - // Start the stream computation - startFileWritingThread(directory.toString) - ssc.start() - } - - def startFileWritingThread(directory: String) { - - val fs = new Path(directory).getFileSystem(new Configuration()) - - val fileWritingThread = new Thread() { - override def run() { - val r = new scala.util.Random() - val text = "This is a sample text file with a random number " - while(true) { - val number = r.nextInt() - val file = new Path(directory, number.toString) - val fos = fs.create(file) - fos.writeChars(text + number) - fos.close() - println("Created text file " + file) - Thread.sleep(1000) - } - } - } - fileWritingThread.start() - } - -} diff --git a/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala b/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala index e60ce483a3..461929fba2 100644 --- a/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala +++ b/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala @@ -5,7 +5,7 @@ import spark.storage.StorageLevel import spark.streaming._ /** - * Produce a streaming count of events received from Flume. + * Produces a count of events received from Flume. * * This should be used in conjunction with an AvroSink in Flume. It will start * an Avro server on at the request host:port address and listen for requests. diff --git a/examples/src/main/scala/spark/streaming/examples/GrepRaw.scala b/examples/src/main/scala/spark/streaming/examples/GrepRaw.scala deleted file mode 100644 index 812faa368a..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/GrepRaw.scala +++ /dev/null @@ -1,32 +0,0 @@ -package spark.streaming.examples - -import spark.util.IntParam -import spark.storage.StorageLevel - -import spark.streaming._ -import spark.streaming.util.RawTextHelper._ - -object GrepRaw { - def main(args: Array[String]) { - if (args.length != 5) { - System.err.println("Usage: GrepRaw ") - System.exit(1) - } - - val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args - - // Create the context - val ssc = new StreamingContext(master, "GrepRaw", Milliseconds(batchMillis)) - - // Warm up the JVMs on master and slave for JIT compilation to kick in - warmUp(ssc.sc) - - - val rawStreams = (1 to numStreams).map(_ => - ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray - val union = ssc.union(rawStreams) - union.filter(_.contains("Alice")).count().foreach(r => - println("Grep count: " + r.collect().mkString)) - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala b/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala new file mode 100644 index 0000000000..8530f5c175 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala @@ -0,0 +1,36 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ + + +/** + * Counts words in new text files created in the given directory + * Usage: HdfsWordCount + * is the Spark master URL. + * is the directory that Spark Streaming will use to find and read new text files. + * + * To run this on your local machine on directory `localdir`, run this example + * `$ ./run spark.streaming.examples.HdfsWordCount local[2] localdir` + * Then create a text file in `localdir` and the words in the file will get counted. + */ +object HdfsWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: HdfsWordCount ") + System.exit(1) + } + + // Create the context + val ssc = new StreamingContext(args(0), "HdfsWordCount", Seconds(2)) + + // Create the FileInputDStream on the directory and use the + // stream to count words in new files created + val lines = ssc.textFileStream(args(1)) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} + diff --git a/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala new file mode 100644 index 0000000000..43c01d5db2 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala @@ -0,0 +1,36 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + * Usage: NetworkWordCount + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ ./run spark.streaming.examples.NetworkWordCount local[2] localhost 9999` + */ +object NetworkWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: NetworkWordCount \n" + + "In local mode, should be 'local[n]' with n > 1") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1)) + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. generated by 'nc') + val lines = ssc.networkTextStream(args(1), args(2).toInt) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala new file mode 100644 index 0000000000..2eec777c54 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala @@ -0,0 +1,46 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel + +import spark.streaming._ +import spark.streaming.util.RawTextHelper + +/** + * Receives text from multiple rawNetworkStreams and counts how many '\n' delimited + * lines have the word 'the' in them. This is useful for benchmarking purposes. This + * will only work with spark.streaming.util.RawTextSender running on all worker nodes + * and with Spark using Kryo serialization (set Java property "spark.serializer" to + * "spark.KryoSerializer"). + * Usage: RawNetworkGrep + * is the Spark master URL + * is the number rawNetworkStreams, which should be same as number + * of work nodes in the cluster + * is "localhost". + * is the port on which RawTextSender is running in the worker nodes. + * is the Spark Streaming batch duration in milliseconds. + */ + +object RawNetworkGrep { + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println("Usage: RawNetworkGrep ") + System.exit(1) + } + + val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args + + // Create the context + val ssc = new StreamingContext(master, "RawNetworkGrep", Milliseconds(batchMillis)) + + // Warm up the JVMs on master and slave for JIT compilation to kick in + RawTextHelper.warmUp(ssc.sc) + + val rawStreams = (1 to numStreams).map(_ => + ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray + val union = ssc.union(rawStreams) + union.filter(_.contains("the")).count().foreach(r => + println("Grep count: " + r.collect().mkString)) + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala deleted file mode 100644 index 338834bc3c..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ /dev/null @@ -1,49 +0,0 @@ -package spark.streaming.examples - -import spark.storage.StorageLevel -import spark.util.IntParam - -import spark.streaming._ -import spark.streaming.StreamingContext._ -import spark.streaming.util.RawTextHelper._ - -import java.util.UUID - -object TopKWordCountRaw { - - def main(args: Array[String]) { - if (args.length != 4) { - System.err.println("Usage: WordCountRaw <# streams> ") - System.exit(1) - } - - val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args - val k = 10 - - // Create the context, and set the checkpoint directory. - // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts - // periodically to HDFS - val ssc = new StreamingContext(master, "TopKWordCountRaw", Seconds(1)) - ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) - - // Warm up the JVMs on master and slave for JIT compilation to kick in - /*warmUp(ssc.sc)*/ - - // Set up the raw network streams that will connect to localhost:port to raw test - // senders on the slaves and generate top K words of last 30 seconds - val lines = (1 to numStreams).map(_ => { - ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) - }) - val union = ssc.union(lines) - val counts = union.mapPartitions(splitAndCountPartitions) - val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreach(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " words from partial top words") - println("Top " + k + " words are " + topK(collectedCounts.toIterator, k).mkString(",")) - }) - - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala b/examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala deleted file mode 100644 index 867a8f42c4..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala +++ /dev/null @@ -1,25 +0,0 @@ -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ - -object WordCountHdfs { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: WordCountHdfs ") - System.exit(1) - } - - // Create the context - val ssc = new StreamingContext(args(0), "WordCountHdfs", Seconds(2)) - - // Create the FileInputDStream on the directory and use the - // stream to count words in new files created - val lines = ssc.textFileStream(args(1)) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - } -} - diff --git a/examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala b/examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala deleted file mode 100644 index eadda60563..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala +++ /dev/null @@ -1,25 +0,0 @@ -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ - -object WordCountNetwork { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: WordCountNetwork \n" + - "In local mode, should be 'local[n]' with n > 1") - System.exit(1) - } - - // Create the context and set the batch size - val ssc = new StreamingContext(args(0), "WordCountNetwork", Seconds(1)) - - // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') - val lines = ssc.networkTextStream(args(1), args(2).toInt) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala deleted file mode 100644 index d93335a8ce..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ /dev/null @@ -1,43 +0,0 @@ -package spark.streaming.examples - -import spark.storage.StorageLevel -import spark.util.IntParam - -import spark.streaming._ -import spark.streaming.StreamingContext._ -import spark.streaming.util.RawTextHelper._ - -import java.util.UUID - -object WordCountRaw { - - def main(args: Array[String]) { - if (args.length != 4) { - System.err.println("Usage: WordCountRaw <# streams> ") - System.exit(1) - } - - val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args - - // Create the context, and set the checkpoint directory. - // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts - // periodically to HDFS - val ssc = new StreamingContext(master, "WordCountRaw", Seconds(1)) - ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) - - // Warm up the JVMs on master and slave for JIT compilation to kick in - warmUp(ssc.sc) - - // Set up the raw network streams that will connect to localhost:port to raw test - // senders on the slaves and generate count of words of last 30 seconds - val lines = (1 to numStreams).map(_ => { - ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) - }) - val union = ssc.union(lines) - val counts = union.mapPartitions(splitAndCountPartitions) - val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) - windowedCounts.foreach(r => println("# unique words = " + r.count())) - - ssc.start() - } -} diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 7256e41af9..215246ba2e 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -154,7 +154,7 @@ class StreamingContext private ( storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[T] = { val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, topics, initialOffsets, storageLevel) - graph.addInputStream(inputStream) + registerInputStream(inputStream) inputStream } @@ -192,7 +192,7 @@ class StreamingContext private ( storageLevel: StorageLevel ): DStream[T] = { val inputStream = new SocketInputDStream[T](this, hostname, port, converter, storageLevel) - graph.addInputStream(inputStream) + registerInputStream(inputStream) inputStream } @@ -208,7 +208,7 @@ class StreamingContext private ( storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 ): DStream[SparkFlumeEvent] = { val inputStream = new FlumeInputDStream(this, hostname, port, storageLevel) - graph.addInputStream(inputStream) + registerInputStream(inputStream) inputStream } @@ -228,13 +228,14 @@ class StreamingContext private ( storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 ): DStream[T] = { val inputStream = new RawInputDStream[T](this, hostname, port, storageLevel) - graph.addInputStream(inputStream) + registerInputStream(inputStream) inputStream } /** * Creates a input stream that monitors a Hadoop-compatible filesystem * for new files and reads them using the given key-value types and input format. + * File names starting with . are ignored. * @param directory HDFS directory to monitor for new file * @tparam K Key type for reading HDFS file * @tparam V Value type for reading HDFS file @@ -244,16 +245,37 @@ class StreamingContext private ( K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K, V]: ClassManifest - ](directory: String): DStream[(K, V)] = { + ] (directory: String): DStream[(K, V)] = { val inputStream = new FileInputDStream[K, V, F](this, directory) - graph.addInputStream(inputStream) + registerInputStream(inputStream) + inputStream + } + + /** + * Creates a input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them using the given key-value types and input format. + * @param directory HDFS directory to monitor for new file + * @param filter Function to filter paths to process + * @param newFilesOnly Should process only new files and ignore existing files in the directory + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file + */ + def fileStream[ + K: ClassManifest, + V: ClassManifest, + F <: NewInputFormat[K, V]: ClassManifest + ] (directory: String, filter: Path => Boolean, newFilesOnly: Boolean): DStream[(K, V)] = { + val inputStream = new FileInputDStream[K, V, F](this, directory, filter, newFilesOnly) + registerInputStream(inputStream) inputStream } + /** * Creates a input stream that monitors a Hadoop-compatible filesystem * for new files and reads them as text files (using key as LongWritable, value - * as Text and input format as TextInputFormat). + * as Text and input format as TextInputFormat). File names starting with . are ignored. * @param directory HDFS directory to monitor for new file */ def textFileStream(directory: String): DStream[String] = { @@ -274,7 +296,7 @@ class StreamingContext private ( defaultRDD: RDD[T] = null ): DStream[T] = { val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD) - graph.addInputStream(inputStream) + registerInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala index cf72095324..1e6ad84b44 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala @@ -14,7 +14,7 @@ private[streaming] class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( @transient ssc_ : StreamingContext, directory: String, - filter: PathFilter = FileInputDStream.defaultPathFilter, + filter: Path => Boolean = FileInputDStream.defaultFilter, newFilesOnly: Boolean = true) extends InputDStream[(K, V)](ssc_) { @@ -60,7 +60,7 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K val latestModTimeFiles = new HashSet[String]() def accept(path: Path): Boolean = { - if (!filter.accept(path)) { + if (!filter(path)) { return false } else { val modTime = fs.getFileStatus(path).getModificationTime() @@ -95,16 +95,8 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K } } +private[streaming] object FileInputDStream { - val defaultPathFilter = new PathFilter with Serializable { - def accept(path: Path): Boolean = { - val file = path.getName() - if (file.startsWith(".") || file.endsWith("_tmp")) { - return false - } else { - return true - } - } - } + def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".") } diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 76b528bec3..00ee903c1e 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -318,7 +318,7 @@ class TestServer(port: Int) extends Logging { } } } catch { - case e: SocketException => println(e) + case e: SocketException => logInfo(e) } finally { logInfo("Connection closed") if (!clientSocket.isClosed) clientSocket.close() -- cgit v1.2.3 From 3b0a3b89ac508b57b8afbd1ca7024ee558a5d1af Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Jan 2013 14:55:49 -0800 Subject: Added better docs for RDDCheckpointData --- core/src/main/scala/spark/RDDCheckpointData.scala | 10 +++++++++- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 4 ++++ 2 files changed, 13 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index e270b6312e..d845a522e4 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -14,15 +14,23 @@ private[spark] object CheckpointState extends Enumeration { } /** - * This class contains all the information of the regarding RDD checkpointing. + * This class contains all the information related to RDD checkpointing. Each instance of this class + * is associated with a RDD. It manages process of checkpointing of the associated RDD, as well as, + * manages the post-checkpoint state by providing the updated splits, iterator and preferred locations + * of the checkpointed RDD. */ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) extends Logging with Serializable { import CheckpointState._ + // The checkpoint state of the associated RDD. var cpState = Initialized + + // The file to which the associated RDD has been checkpointed to @transient var cpFile: Option[String] = None + + // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. @transient var cpRDD: Option[RDD[T]] = None // Mark the RDD for checkpointing diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 1a88d402c3..86c63ca2f4 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -13,6 +13,10 @@ private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends override val index: Int = idx } +/** + * This RDD represents a RDD checkpoint file (similar to HadoopRDD). + */ +private[spark] class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) extends RDD[T](sc, Nil) { -- cgit v1.2.3 From f8d579a0c05b7d29b59e541b483ded471d14ec17 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 27 Dec 2012 13:30:07 -0800 Subject: Remove dependencies on sun jvm classes. Instead use reflection to infer HotSpot options and total physical memory size --- core/src/main/scala/spark/SizeEstimator.scala | 13 ++++++++++--- .../spark/deploy/worker/WorkerArguments.scala | 22 +++++++++++++++++++--- 2 files changed, 29 insertions(+), 6 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala index 7c3e8640e9..d4e1157250 100644 --- a/core/src/main/scala/spark/SizeEstimator.scala +++ b/core/src/main/scala/spark/SizeEstimator.scala @@ -9,7 +9,6 @@ import java.util.Random import javax.management.MBeanServer import java.lang.management.ManagementFactory -import com.sun.management.HotSpotDiagnosticMXBean import scala.collection.mutable.ArrayBuffer @@ -76,12 +75,20 @@ private[spark] object SizeEstimator extends Logging { if (System.getProperty("spark.test.useCompressedOops") != null) { return System.getProperty("spark.test.useCompressedOops").toBoolean } + try { val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic" val server = ManagementFactory.getPlatformMBeanServer() + + // NOTE: This should throw an exception in non-Sun JVMs + val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean") + val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption", + Class.forName("java.lang.String")) + val bean = ManagementFactory.newPlatformMXBeanProxy(server, - hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean]) - return bean.getVMOption("UseCompressedOops").getValue.toBoolean + hotSpotMBeanName, hotSpotMBeanClass) + // TODO: We could use reflection on the VMOption returned ? + return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") } catch { case e: Exception => { // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala index 340920025b..37524a7c82 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala @@ -104,9 +104,25 @@ private[spark] class WorkerArguments(args: Array[String]) { } def inferDefaultMemory(): Int = { - val bean = ManagementFactory.getOperatingSystemMXBean - .asInstanceOf[com.sun.management.OperatingSystemMXBean] - val totalMb = (bean.getTotalPhysicalMemorySize / 1024 / 1024).toInt + val ibmVendor = System.getProperty("java.vendor").contains("IBM") + var totalMb = 0 + try { + val bean = ManagementFactory.getOperatingSystemMXBean() + if (ibmVendor) { + val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean") + val method = beanClass.getDeclaredMethod("getTotalPhysicalMemory") + totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt + } else { + val beanClass = Class.forName("com.sun.management.OperatingSystemMXBean") + val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize") + totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt + } + } catch { + case e: Exception => { + totalMb = 2*1024 + System.out.println("Failed to get total physical memory. Using " + totalMb + " MB") + } + } // Leave out 1 GB for the operating system, but don't return a negative memory size math.max(totalMb - 1024, 512) } -- cgit v1.2.3 From aed368a970bbaee4bdf297ba3f6f1b0fa131452c Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sat, 29 Dec 2012 16:23:43 -0800 Subject: Update Hadoop dependency to 1.0.3 as 0.20 has Sun specific dependencies. Also fix SequenceFileRDDFunctions to pick the right type conversion across Hadoop versions --- core/src/main/scala/spark/SequenceFileRDDFunctions.scala | 8 +++++++- project/SparkBuild.scala | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala index a34aee69c1..6b4a11d6d3 100644 --- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala @@ -42,7 +42,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) { classManifest[T].erasure } else { - implicitly[T => Writable].getClass.getMethods()(0).getReturnType + // We get the type of the Writable class by looking at the apply method which converts + // from T to Writable. Since we have two apply methods we filter out the one which + // is of the form "java.lang.Object apply(java.lang.Object)" + implicitly[T => Writable].getClass.getDeclaredMethods().filter( + m => m.getReturnType().toString != "java.lang.Object" && + m.getName() == "apply")(0).getReturnType + } // TODO: use something like WritableConverter to avoid reflection } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 842d0fa96b..7c7c33131a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -10,7 +10,7 @@ import twirl.sbt.TwirlPlugin._ object SparkBuild extends Build { // Hadoop version to build against. For example, "0.20.2", "0.20.205.0", or // "1.0.3" for Apache releases, or "0.20.2-cdh3u5" for Cloudera Hadoop. - val HADOOP_VERSION = "0.20.205.0" + val HADOOP_VERSION = "1.0.3" val HADOOP_MAJOR_VERSION = "1" // For Hadoop 2 versions such as "2.0.0-mr1-cdh4.1.1", set the HADOOP_MAJOR_VERSION to "2" -- cgit v1.2.3 From 77d751731ccd06e161e3ef10540f8165d964282f Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sat, 29 Dec 2012 18:28:00 -0800 Subject: Remove unused BoundedMemoryCache file and associated test case. --- core/src/main/scala/spark/BoundedMemoryCache.scala | 118 --------------------- .../test/scala/spark/BoundedMemoryCacheSuite.scala | 58 ---------- 2 files changed, 176 deletions(-) delete mode 100644 core/src/main/scala/spark/BoundedMemoryCache.scala delete mode 100644 core/src/test/scala/spark/BoundedMemoryCacheSuite.scala (limited to 'core') diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala deleted file mode 100644 index e8392a194f..0000000000 --- a/core/src/main/scala/spark/BoundedMemoryCache.scala +++ /dev/null @@ -1,118 +0,0 @@ -package spark - -import java.util.LinkedHashMap - -/** - * An implementation of Cache that estimates the sizes of its entries and attempts to limit its - * total memory usage to a fraction of the JVM heap. Objects' sizes are estimated using - * SizeEstimator, which has limitations; most notably, we will overestimate total memory used if - * some cache entries have pointers to a shared object. Nonetheless, this Cache should work well - * when most of the space is used by arrays of primitives or of simple classes. - */ -private[spark] class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging { - logInfo("BoundedMemoryCache.maxBytes = " + maxBytes) - - def this() { - this(BoundedMemoryCache.getMaxBytes) - } - - private var currentBytes = 0L - private val map = new LinkedHashMap[(Any, Int), Entry](32, 0.75f, true) - - override def get(datasetId: Any, partition: Int): Any = { - synchronized { - val entry = map.get((datasetId, partition)) - if (entry != null) { - entry.value - } else { - null - } - } - } - - override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = { - val key = (datasetId, partition) - logInfo("Asked to add key " + key) - val size = estimateValueSize(key, value) - synchronized { - if (size > getCapacity) { - return CachePutFailure() - } else if (ensureFreeSpace(datasetId, size)) { - logInfo("Adding key " + key) - map.put(key, new Entry(value, size)) - currentBytes += size - logInfo("Number of entries is now " + map.size) - return CachePutSuccess(size) - } else { - logInfo("Didn't add key " + key + " because we would have evicted part of same dataset") - return CachePutFailure() - } - } - } - - override def getCapacity: Long = maxBytes - - /** - * Estimate sizeOf 'value' - */ - private def estimateValueSize(key: (Any, Int), value: Any) = { - val startTime = System.currentTimeMillis - val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef]) - val timeTaken = System.currentTimeMillis - startTime - logInfo("Estimated size for key %s is %d".format(key, size)) - logInfo("Size estimation for key %s took %d ms".format(key, timeTaken)) - size - } - - /** - * Remove least recently used entries from the map until at least space bytes are free, in order - * to make space for a partition from the given dataset ID. If this cannot be done without - * evicting other data from the same dataset, returns false; otherwise, returns true. Assumes - * that a lock is held on the BoundedMemoryCache. - */ - private def ensureFreeSpace(datasetId: Any, space: Long): Boolean = { - logInfo("ensureFreeSpace(%s, %d) called with curBytes=%d, maxBytes=%d".format( - datasetId, space, currentBytes, maxBytes)) - val iter = map.entrySet.iterator // Will give entries in LRU order - while (maxBytes - currentBytes < space && iter.hasNext) { - val mapEntry = iter.next() - val (entryDatasetId, entryPartition) = mapEntry.getKey - if (entryDatasetId == datasetId) { - // Cannot make space without removing part of the same dataset, or a more recently used one - return false - } - reportEntryDropped(entryDatasetId, entryPartition, mapEntry.getValue) - currentBytes -= mapEntry.getValue.size - iter.remove() - } - return true - } - - protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { - logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size)) - // TODO: remove BoundedMemoryCache - - val (keySpaceId, innerDatasetId) = datasetId.asInstanceOf[(Any, Any)] - innerDatasetId match { - case rddId: Int => - SparkEnv.get.cacheTracker.dropEntry(rddId, partition) - case broadcastUUID: java.util.UUID => - // TODO: Maybe something should be done if the broadcasted variable falls out of cache - case _ => - } - } -} - -// An entry in our map; stores a cached object and its size in bytes -private[spark] case class Entry(value: Any, size: Long) - -private[spark] object BoundedMemoryCache { - /** - * Get maximum cache capacity from system configuration - */ - def getMaxBytes: Long = { - val memoryFractionToUse = System.getProperty("spark.boundedMemoryCache.memoryFraction", "0.66").toDouble - (Runtime.getRuntime.maxMemory * memoryFractionToUse).toLong - } -} - diff --git a/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala b/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala deleted file mode 100644 index 37cafd1e8e..0000000000 --- a/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala +++ /dev/null @@ -1,58 +0,0 @@ -package spark - -import org.scalatest.FunSuite -import org.scalatest.PrivateMethodTester -import org.scalatest.matchers.ShouldMatchers - -// TODO: Replace this with a test of MemoryStore -class BoundedMemoryCacheSuite extends FunSuite with PrivateMethodTester with ShouldMatchers { - test("constructor test") { - val cache = new BoundedMemoryCache(60) - expect(60)(cache.getCapacity) - } - - test("caching") { - // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case - val oldArch = System.setProperty("os.arch", "amd64") - val oldOops = System.setProperty("spark.test.useCompressedOops", "true") - val initialize = PrivateMethod[Unit]('initialize) - SizeEstimator invokePrivate initialize() - - val cache = new BoundedMemoryCache(60) { - //TODO sorry about this, but there is not better way how to skip 'cacheTracker.dropEntry' - override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { - logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size)) - } - } - - // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length - // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6. - // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html - // Work around to check for either. - - //should be OK - cache.put("1", 0, "Meh") should (equal (CachePutSuccess(56)) or equal (CachePutSuccess(48))) - - //we cannot add this to cache (there is not enough space in cache) & we cannot evict the only value from - //cache because it's from the same dataset - expect(CachePutFailure())(cache.put("1", 1, "Meh")) - - //should be OK, dataset '1' can be evicted from cache - cache.put("2", 0, "Meh") should (equal (CachePutSuccess(56)) or equal (CachePutSuccess(48))) - - //should fail, cache should obey it's capacity - expect(CachePutFailure())(cache.put("3", 0, "Very_long_and_useless_string")) - - if (oldArch != null) { - System.setProperty("os.arch", oldArch) - } else { - System.clearProperty("os.arch") - } - - if (oldOops != null) { - System.setProperty("spark.test.useCompressedOops", oldOops) - } else { - System.clearProperty("spark.test.useCompressedOops") - } - } -} -- cgit v1.2.3 From 55c66d365f76f3e5ecc6b850ba81c84b320f6772 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 7 Jan 2013 15:19:33 -0800 Subject: Use a dummy string class in Size Estimator tests to make it resistant to jdk versions --- core/src/test/scala/spark/SizeEstimatorSuite.scala | 33 ++++++++++++++-------- 1 file changed, 21 insertions(+), 12 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala index 17f366212b..bf3b2e1eed 100644 --- a/core/src/test/scala/spark/SizeEstimatorSuite.scala +++ b/core/src/test/scala/spark/SizeEstimatorSuite.scala @@ -20,6 +20,15 @@ class DummyClass4(val d: DummyClass3) { val x: Int = 0 } +object DummyString { + def apply(str: String) : DummyString = new DummyString(str.toArray) +} +class DummyString(val arr: Array[Char]) { + override val hashCode: Int = 0 + // JDK-7 has an extra hash32 field http://hg.openjdk.java.net/jdk7u/jdk7u6/jdk/rev/11987e85555f + @transient val hash32: Int = 0 +} + class SizeEstimatorSuite extends FunSuite with BeforeAndAfterAll with PrivateMethodTester with ShouldMatchers { @@ -50,10 +59,10 @@ class SizeEstimatorSuite // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html // Work around to check for either. test("strings") { - SizeEstimator.estimate("") should (equal (48) or equal (40)) - SizeEstimator.estimate("a") should (equal (56) or equal (48)) - SizeEstimator.estimate("ab") should (equal (56) or equal (48)) - SizeEstimator.estimate("abcdefgh") should (equal(64) or equal(56)) + SizeEstimator.estimate(DummyString("")) should (equal (48) or equal (40)) + SizeEstimator.estimate(DummyString("a")) should (equal (56) or equal (48)) + SizeEstimator.estimate(DummyString("ab")) should (equal (56) or equal (48)) + SizeEstimator.estimate(DummyString("abcdefgh")) should (equal(64) or equal(56)) } test("primitive arrays") { @@ -105,10 +114,10 @@ class SizeEstimatorSuite val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - expect(40)(SizeEstimator.estimate("")) - expect(48)(SizeEstimator.estimate("a")) - expect(48)(SizeEstimator.estimate("ab")) - expect(56)(SizeEstimator.estimate("abcdefgh")) + expect(40)(SizeEstimator.estimate(DummyString(""))) + expect(48)(SizeEstimator.estimate(DummyString("a"))) + expect(48)(SizeEstimator.estimate(DummyString("ab"))) + expect(56)(SizeEstimator.estimate(DummyString("abcdefgh"))) resetOrClear("os.arch", arch) } @@ -124,10 +133,10 @@ class SizeEstimatorSuite val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - SizeEstimator.estimate("") should (equal (64) or equal (56)) - SizeEstimator.estimate("a") should (equal (72) or equal (64)) - SizeEstimator.estimate("ab") should (equal (72) or equal (64)) - SizeEstimator.estimate("abcdefgh") should (equal (80) or equal (72)) + SizeEstimator.estimate(DummyString("")) should (equal (64) or equal (56)) + SizeEstimator.estimate(DummyString("a")) should (equal (72) or equal (64)) + SizeEstimator.estimate(DummyString("ab")) should (equal (72) or equal (64)) + SizeEstimator.estimate(DummyString("abcdefgh")) should (equal (80) or equal (72)) resetOrClear("os.arch", arch) resetOrClear("spark.test.useCompressedOops", oops) -- cgit v1.2.3 From 4719e6d8fe6d93734f5bbe6c91dcc4616c1ed317 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Jan 2013 16:06:07 -0800 Subject: Changed locations for unit test logs. --- bagel/src/test/resources/log4j.properties | 4 ++-- core/src/test/resources/log4j.properties | 4 ++-- repl/src/test/resources/log4j.properties | 4 ++-- streaming/src/test/resources/log4j.properties | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) (limited to 'core') diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties index 4c99e450bc..83d05cab2f 100644 --- a/bagel/src/test/resources/log4j.properties +++ b/bagel/src/test/resources/log4j.properties @@ -1,8 +1,8 @@ -# Set everything to be logged to the console +# Set everything to be logged to the file bagel/target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false -log4j.appender.file.file=spark-tests.log +log4j.appender.file.file=bagel/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index 5ed388e91b..6ec89c0184 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -1,8 +1,8 @@ -# Set everything to be logged to the file spark-tests.log +# Set everything to be logged to the file core/target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false -log4j.appender.file.file=spark-tests.log +log4j.appender.file.file=core/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties index 4c99e450bc..cfb1a390e6 100644 --- a/repl/src/test/resources/log4j.properties +++ b/repl/src/test/resources/log4j.properties @@ -1,8 +1,8 @@ -# Set everything to be logged to the console +# Set everything to be logged to the repl/target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false -log4j.appender.file.file=spark-tests.log +log4j.appender.file.file=repl/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 33bafebaab..edfa1243fa 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -1,8 +1,8 @@ -# Set everything to be logged to the file streaming-tests.log +# Set everything to be logged to the file streaming/target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false -log4j.appender.file.file=streaming-tests.log +log4j.appender.file.file=streaming/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n -- cgit v1.2.3 From b1336e2fe458b92dcf60dcd249c41c7bdcc8be6d Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 7 Jan 2013 17:00:32 -0800 Subject: Update expected size of strings to match our dummy string class --- core/src/test/scala/spark/SizeEstimatorSuite.scala | 31 +++++++++------------- 1 file changed, 13 insertions(+), 18 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala index bf3b2e1eed..e235ef2f67 100644 --- a/core/src/test/scala/spark/SizeEstimatorSuite.scala +++ b/core/src/test/scala/spark/SizeEstimatorSuite.scala @@ -3,7 +3,6 @@ package spark import org.scalatest.FunSuite import org.scalatest.BeforeAndAfterAll import org.scalatest.PrivateMethodTester -import org.scalatest.matchers.ShouldMatchers class DummyClass1 {} @@ -30,7 +29,7 @@ class DummyString(val arr: Array[Char]) { } class SizeEstimatorSuite - extends FunSuite with BeforeAndAfterAll with PrivateMethodTester with ShouldMatchers { + extends FunSuite with BeforeAndAfterAll with PrivateMethodTester { var oldArch: String = _ var oldOops: String = _ @@ -54,15 +53,13 @@ class SizeEstimatorSuite expect(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3))) } - // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length. - // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6. - // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html - // Work around to check for either. + // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors + // (Sun vs IBM). Use a DummyString class to make tests deterministic. test("strings") { - SizeEstimator.estimate(DummyString("")) should (equal (48) or equal (40)) - SizeEstimator.estimate(DummyString("a")) should (equal (56) or equal (48)) - SizeEstimator.estimate(DummyString("ab")) should (equal (56) or equal (48)) - SizeEstimator.estimate(DummyString("abcdefgh")) should (equal(64) or equal(56)) + expect(40)(SizeEstimator.estimate(DummyString(""))) + expect(48)(SizeEstimator.estimate(DummyString("a"))) + expect(48)(SizeEstimator.estimate(DummyString("ab"))) + expect(56)(SizeEstimator.estimate(DummyString("abcdefgh"))) } test("primitive arrays") { @@ -122,10 +119,8 @@ class SizeEstimatorSuite resetOrClear("os.arch", arch) } - // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length. - // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6. - // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html - // Work around to check for either. + // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors + // (Sun vs IBM). Use a DummyString class to make tests deterministic. test("64-bit arch with no compressed oops") { val arch = System.setProperty("os.arch", "amd64") val oops = System.setProperty("spark.test.useCompressedOops", "false") @@ -133,10 +128,10 @@ class SizeEstimatorSuite val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - SizeEstimator.estimate(DummyString("")) should (equal (64) or equal (56)) - SizeEstimator.estimate(DummyString("a")) should (equal (72) or equal (64)) - SizeEstimator.estimate(DummyString("ab")) should (equal (72) or equal (64)) - SizeEstimator.estimate(DummyString("abcdefgh")) should (equal (80) or equal (72)) + expect(56)(SizeEstimator.estimate(DummyString(""))) + expect(64)(SizeEstimator.estimate(DummyString("a"))) + expect(64)(SizeEstimator.estimate(DummyString("ab"))) + expect(72)(SizeEstimator.estimate(DummyString("abcdefgh"))) resetOrClear("os.arch", arch) resetOrClear("spark.test.useCompressedOops", oops) -- cgit v1.2.3 From 4bbe07e5ece81fa874d2412bcc165179313a7619 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 7 Jan 2013 17:46:22 -0800 Subject: Activate hadoop1 profile by default for maven builds --- bagel/pom.xml | 3 +++ core/pom.xml | 5 ++++- examples/pom.xml | 3 +++ pom.xml | 3 +++ repl-bin/pom.xml | 3 +++ repl/pom.xml | 3 +++ 6 files changed, 19 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/bagel/pom.xml b/bagel/pom.xml index a8256a6e8b..4ca643bbb7 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -45,6 +45,9 @@ hadoop1 + + true + org.spark-project diff --git a/core/pom.xml b/core/pom.xml index ae52c20657..cd789a7db0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -159,6 +159,9 @@ hadoop1 + + true + org.apache.hadoop @@ -267,4 +270,4 @@ - \ No newline at end of file + diff --git a/examples/pom.xml b/examples/pom.xml index 782c026d73..9e638c8284 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -45,6 +45,9 @@ hadoop1 + + true + org.spark-project diff --git a/pom.xml b/pom.xml index fe5b1d0ee4..0e2d93c170 100644 --- a/pom.xml +++ b/pom.xml @@ -481,6 +481,9 @@ hadoop1 + + true + 1 diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index 0667b71cc7..aa9895eda2 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -70,6 +70,9 @@ hadoop1 + + true + hadoop1 diff --git a/repl/pom.xml b/repl/pom.xml index 114e3e9932..ba7a051310 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -72,6 +72,9 @@ hadoop1 + + true + hadoop1 -- cgit v1.2.3 From c41042c816c2d6299aa7d93529b7c39db5d5c03a Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Wed, 26 Dec 2012 15:52:51 -0800 Subject: Log preferred hosts --- core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index cf4aae03a7..dda7a6c64a 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -201,7 +201,9 @@ private[spark] class TaskSetManager( val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch val preferred = isPreferredLocation(task, host) - val prefStr = if (preferred) "preferred" else "non-preferred" + val prefStr = if (preferred) "preferred" else + "non-preferred, not one of " + + task.preferredLocations.mkString(", ") logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( taskSet.id, index, taskId, slaveId, host, prefStr)) // Do various bookkeeping -- cgit v1.2.3 From 4725b0f6439337c7a0f5f6fc7034c6f6b9488ae9 Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Mon, 7 Jan 2013 20:07:08 -0800 Subject: Fixing if/else coding style for preferred hosts logging --- core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index dda7a6c64a..a842afcdeb 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -201,9 +201,8 @@ private[spark] class TaskSetManager( val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch val preferred = isPreferredLocation(task, host) - val prefStr = if (preferred) "preferred" else - "non-preferred, not one of " + - task.preferredLocations.mkString(", ") + val prefStr = if (preferred) "preferred" + else "non-preferred, not one of " + task.preferredLocations.mkString(", ") logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( taskSet.id, index, taskId, slaveId, host, prefStr)) // Do various bookkeeping -- cgit v1.2.3 From f7adb382ace7f54c5093bf90574b3f9dd0d35534 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Tue, 8 Jan 2013 03:19:43 -0800 Subject: Activate hadoop1 if property hadoop is missing. hadoop2 can be activated now by using -Dhadoop -Phadoop2. --- bagel/pom.xml | 4 +++- core/pom.xml | 4 +++- examples/pom.xml | 4 +++- pom.xml | 4 +++- repl-bin/pom.xml | 4 +++- repl/pom.xml | 4 +++- 6 files changed, 18 insertions(+), 6 deletions(-) (limited to 'core') diff --git a/bagel/pom.xml b/bagel/pom.xml index 4ca643bbb7..85b2077026 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -46,7 +46,9 @@ hadoop1 - true + + !hadoop + diff --git a/core/pom.xml b/core/pom.xml index cd789a7db0..005d8fe498 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -160,7 +160,9 @@ hadoop1 - true + + !hadoop + diff --git a/examples/pom.xml b/examples/pom.xml index 9e638c8284..3f738a3f8c 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -46,7 +46,9 @@ hadoop1 - true + + !hadoop + diff --git a/pom.xml b/pom.xml index 0e2d93c170..ea5b9c9d05 100644 --- a/pom.xml +++ b/pom.xml @@ -482,7 +482,9 @@ hadoop1 - true + + !hadoop + 1 diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index aa9895eda2..fecb01f3cd 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -71,7 +71,9 @@ hadoop1 - true + + !hadoop + hadoop1 diff --git a/repl/pom.xml b/repl/pom.xml index ba7a051310..04b2c35beb 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -73,7 +73,9 @@ hadoop1 - true + + !hadoop + hadoop1 -- cgit v1.2.3 From e4cb72da8a5428c6b9097e92ddbdf4ceee087b85 Mon Sep 17 00:00:00 2001 From: shane-huang Date: Tue, 8 Jan 2013 22:40:58 +0800 Subject: Fix an issue in ConnectionManager where sendingMessage may create too many unnecessary SendingConnections. --- core/src/main/scala/spark/network/Connection.scala | 7 +++++-- .../main/scala/spark/network/ConnectionManager.scala | 17 +++++++++-------- .../scala/spark/network/ConnectionManagerTest.scala | 18 +++++++++--------- 3 files changed, 23 insertions(+), 19 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 80262ab7b4..95096fd0ba 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -135,8 +135,11 @@ extends Connection(SocketChannel.open, selector_) { val chunk = message.getChunkForSending(defaultChunkSize) if (chunk.isDefined) { messages += message // this is probably incorrect, it wont work as fifo - if (!message.started) logDebug("Starting to send [" + message + "]") - message.started = true + if (!message.started) { + logDebug("Starting to send [" + message + "]") + message.started = true + message.startTime = System.currentTimeMillis + } return chunk } else { /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 642fa4b525..e7bd2d3bbd 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -43,12 +43,12 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } val selector = SelectorProvider.provider.openSelector() - val handleMessageExecutor = Executors.newFixedThreadPool(4) + val handleMessageExecutor = Executors.newFixedThreadPool(20) val serverChannel = ServerSocketChannel.open() val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] val messageStatuses = new HashMap[Int, MessageStatus] - val connectionRequests = new SynchronizedQueue[SendingConnection] + val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] val sendMessageRequests = new Queue[(Message, SendingConnection)] @@ -78,11 +78,12 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def run() { try { - while(!selectorThread.isInterrupted) { - while(!connectionRequests.isEmpty) { - val sendingConnection = connectionRequests.dequeue + while(!selectorThread.isInterrupted) { + for( (connectionManagerId, sendingConnection) <- connectionRequests) { + //val sendingConnection = connectionRequests.dequeue sendingConnection.connect() addConnection(sendingConnection) + connectionRequests -= connectionManagerId } sendMessageRequests.synchronized { while(!sendMessageRequests.isEmpty) { @@ -300,8 +301,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = new SendingConnection(inetSocketAddress, selector) - connectionRequests += newConnection + val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector)) newConnection } val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress) @@ -465,7 +465,7 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) + val g = Await.result(f, 10 second) if (!g.isDefined) println("Failed") }) val finishTime = System.currentTimeMillis @@ -473,6 +473,7 @@ private[spark] object ConnectionManager { val mb = size * count / 1024.0 / 1024.0 val ms = finishTime - startTime val tput = mb * 1000.0 / ms + println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") println("--------------------------") println() } diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala index 47ceaf3c07..0e79c518e0 100644 --- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala @@ -13,8 +13,8 @@ import akka.util.duration._ private[spark] object ConnectionManagerTest extends Logging{ def main(args: Array[String]) { - if (args.length < 2) { - println("Usage: ConnectionManagerTest ") + if (args.length < 5) { + println("Usage: ConnectionManagerTest ") System.exit(1) } @@ -29,16 +29,16 @@ private[spark] object ConnectionManagerTest extends Logging{ /*println("Slaves")*/ /*slaves.foreach(println)*/ - - val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map( + val tasknum = args(2).toInt + val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map( i => SparkEnv.get.connectionManager.id).collect() println("\nSlave ConnectionManagerIds") slaveConnManagerIds.foreach(println) println - val count = 10 + val count = args(4).toInt (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => { + val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { val connManager = SparkEnv.get.connectionManager val thisConnManagerId = connManager.id connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { @@ -46,7 +46,7 @@ private[spark] object ConnectionManagerTest extends Logging{ None }) - val size = 100 * 1024 * 1024 + val size = (args(3).toInt) * 1024 * 1024 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip @@ -56,13 +56,13 @@ private[spark] object ConnectionManagerTest extends Logging{ logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) }) - val results = futures.map(f => Await.result(f, 1.second)) + val results = futures.map(f => Await.result(f, 999.second)) val finishTime = System.currentTimeMillis Thread.sleep(5000) val mb = size * results.size / 1024.0 / 1024.0 val ms = finishTime - startTime - val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" + val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" logInfo(resultStr) resultStr }).collect() -- cgit v1.2.3 From 8ac0f35be42765fcd6f02dcf0f070f2ef2377a85 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 8 Jan 2013 09:57:45 -0600 Subject: Add JavaRDDLike.keyBy. --- core/src/main/scala/spark/api/java/JavaRDDLike.scala | 8 ++++++++ core/src/test/scala/spark/JavaAPISuite.java | 12 ++++++++++++ 2 files changed, 20 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 81d3a94466..d15f6dd02f 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -298,4 +298,12 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Save this RDD as a SequenceFile of serialized objects. */ def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path) + + /** + * Creates tuples of the elements in this RDD by applying `f`. + */ + def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = { + implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] + JavaPairRDD.fromRDD(rdd.keyBy(f)) + } } diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 0817d1146c..c61913fc82 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -629,4 +629,16 @@ public class JavaAPISuite implements Serializable { floatAccum.setValue(5.0f); Assert.assertEquals((Float) 5.0f, floatAccum.value()); } + + @Test + public void keyBy() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); + List> s = rdd.keyBy(new Function() { + public String call(Integer t) throws Exception { + return t.toString(); + } + }).collect(); + Assert.assertEquals(new Tuple2("1", 1), s.get(0)); + Assert.assertEquals(new Tuple2("2", 2), s.get(1)); + } } -- cgit v1.2.3 From b57dd0f16024a82dfc223e69528b9908b931f068 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 8 Jan 2013 16:04:41 -0800 Subject: Add mapPartitionsWithSplit() to PySpark. --- .../main/scala/spark/api/python/PythonRDD.scala | 5 ++++ docs/python-programming-guide.md | 1 - python/pyspark/rdd.py | 33 ++++++++++++++-------- python/pyspark/worker.py | 4 ++- 4 files changed, 30 insertions(+), 13 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 79d824d494..f431ef28d3 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -65,6 +65,9 @@ private[spark] class PythonRDD[T: ClassManifest]( SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) val dOut = new DataOutputStream(proc.getOutputStream) + // Split index + dOut.writeInt(split.index) + // Broadcast variables dOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { dOut.writeLong(broadcast.id) @@ -72,10 +75,12 @@ private[spark] class PythonRDD[T: ClassManifest]( dOut.write(broadcast.value) dOut.flush() } + // Serialized user code for (elem <- command) { out.println(elem) } out.flush() + // Data values for (elem <- parent.iterator(split, context)) { PythonRDD.writeAsPickle(elem, dOut) } diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index d963551296..78ef310a00 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -19,7 +19,6 @@ There are a few key differences between the Python and Scala APIs: - Accumulators - Special functions on RDDs of doubles, such as `mean` and `stdev` - `lookup` - - `mapPartitionsWithSplit` - `persist` at storage levels other than `MEMORY_ONLY` - `sample` - `sort` diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4ba417b2a2..1d36da42b0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -55,7 +55,7 @@ class RDD(object): """ Return a new RDD containing the distinct elements in this RDD. """ - def func(iterator): return imap(f, iterator) + def func(split, iterator): return imap(f, iterator) return PipelinedRDD(self, func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -69,8 +69,8 @@ class RDD(object): >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ - def func(iterator): return chain.from_iterable(imap(f, iterator)) - return self.mapPartitions(func, preservesPartitioning) + def func(s, iterator): return chain.from_iterable(imap(f, iterator)) + return self.mapPartitionsWithSplit(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): """ @@ -81,9 +81,20 @@ class RDD(object): >>> rdd.mapPartitions(f).collect() [3, 7] """ - return PipelinedRDD(self, f, preservesPartitioning) + def func(s, iterator): return f(iterator) + return self.mapPartitionsWithSplit(func) + + def mapPartitionsWithSplit(self, f, preservesPartitioning=False): + """ + Return a new RDD by applying a function to each partition of this RDD, + while tracking the index of the original partition. - # TODO: mapPartitionsWithSplit + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(splitIndex, iterator): yield splitIndex + >>> rdd.mapPartitionsWithSplit(f).sum() + 6 + """ + return PipelinedRDD(self, f, preservesPartitioning) def filter(self, f): """ @@ -362,7 +373,7 @@ class RDD(object): >>> ''.join(input(glob(tempFile.name + "/part-0000*"))) '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n' """ - def func(iterator): + def func(split, iterator): return (str(x).encode("utf-8") for x in iterator) keyed = PipelinedRDD(self, func) keyed._bypass_serializer = True @@ -500,7 +511,7 @@ class RDD(object): # Transferring O(n) objects to Java is too expensive. Instead, we'll # form the hash buckets in Python, transferring O(numSplits) objects # to Java. Each object is a (splitNumber, [objects]) pair. - def add_shuffle_key(iterator): + def add_shuffle_key(split, iterator): buckets = defaultdict(list) for (k, v) in iterator: buckets[hashFunc(k) % numSplits].append((k, v)) @@ -653,8 +664,8 @@ class PipelinedRDD(RDD): def __init__(self, prev, func, preservesPartitioning=False): if isinstance(prev, PipelinedRDD) and not prev.is_cached: prev_func = prev.func - def pipeline_func(iterator): - return func(prev_func(iterator)) + def pipeline_func(split, iterator): + return func(split, prev_func(split, iterator)) self.func = pipeline_func self.preservesPartitioning = \ prev.preservesPartitioning and preservesPartitioning @@ -677,8 +688,8 @@ class PipelinedRDD(RDD): if not self._bypass_serializer and self.ctx.batchSize != 1: oldfunc = self.func batchSize = self.ctx.batchSize - def batched_func(iterator): - return batched(oldfunc(iterator), batchSize) + def batched_func(split, iterator): + return batched(oldfunc(split, iterator), batchSize) func = batched_func cmds = [func, self._bypass_serializer] pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 9f6b507dbd..3d792bbaa2 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -21,6 +21,7 @@ def load_obj(): def main(): + split_index = read_int(sys.stdin) num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): bid = read_long(sys.stdin) @@ -32,7 +33,8 @@ def main(): dumps = lambda x: x else: dumps = dump_pickle - for obj in func(read_from_pickle_file(sys.stdin)): + iterator = read_from_pickle_file(sys.stdin) + for obj in func(split_index, iterator): write_with_length(dumps(obj), old_stdout) -- cgit v1.2.3 From 9cc764f52323baa3a218ce9e301d3cc98f1e8b20 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 8 Jan 2013 22:29:57 -0800 Subject: Code style --- core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index a842afcdeb..a089b71644 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -201,8 +201,11 @@ private[spark] class TaskSetManager( val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch val preferred = isPreferredLocation(task, host) - val prefStr = if (preferred) "preferred" - else "non-preferred, not one of " + task.preferredLocations.mkString(", ") + val prefStr = if (preferred) { + "preferred" + } else { + "non-preferred, not one of " + task.preferredLocations.mkString(", ") + } logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( taskSet.id, index, taskId, slaveId, host, prefStr)) // Do various bookkeeping -- cgit v1.2.3 From 269fe018c73a0d4e12a3c881dbd3bd807e504891 Mon Sep 17 00:00:00 2001 From: Tyson Date: Wed, 9 Jan 2013 10:35:59 -0500 Subject: JSON object definitions --- .../src/main/scala/spark/deploy/JsonProtocol.scala | 59 ++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 core/src/main/scala/spark/deploy/JsonProtocol.scala (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala new file mode 100644 index 0000000000..dc7da85f9c --- /dev/null +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -0,0 +1,59 @@ +package spark.deploy + +import master.{JobInfo, WorkerInfo} +import spray.json._ + +/** + * spray-json helper class containing implicit conversion to json for marshalling responses + */ +private[spark] object JsonProtocol extends DefaultJsonProtocol { + import cc.spray.json._ + + implicit object WorkerInfoJsonFormat extends RootJsonWriter[WorkerInfo] { + def write(obj: WorkerInfo) = JsObject( + "id" -> JsString(obj.id), + "host" -> JsString(obj.host), + "webuiaddress" -> JsString(obj.webUiAddress), + "cores" -> JsNumber(obj.cores), + "coresused" -> JsNumber(obj.coresUsed), + "memory" -> JsNumber(obj.memory), + "memoryused" -> JsNumber(obj.memoryUsed) + ) + } + + implicit object JobInfoJsonFormat extends RootJsonWriter[JobInfo] { + def write(obj: JobInfo) = JsObject( + "starttime" -> JsNumber(obj.startTime), + "id" -> JsString(obj.id), + "name" -> JsString(obj.desc.name), + "cores" -> JsNumber(obj.desc.cores), + "user" -> JsString(obj.desc.user), + "memoryperslave" -> JsNumber(obj.desc.memoryPerSlave), + "submitdate" -> JsString(obj.submitDate.toString)) + } + + implicit object MasterStateJsonFormat extends RootJsonWriter[MasterState] { + def write(obj: MasterState) = JsObject( + "url" -> JsString("spark://" + obj.uri), + "workers" -> JsArray(obj.workers.toList.map(_.toJson)), + "cores" -> JsNumber(obj.workers.map(_.cores).sum), + "coresused" -> JsNumber(obj.workers.map(_.coresUsed).sum), + "memory" -> JsNumber(obj.workers.map(_.memory).sum), + "memoryused" -> JsNumber(obj.workers.map(_.memoryUsed).sum), + "activejobs" -> JsArray(obj.activeJobs.toList.map(_.toJson)), + "completedjobs" -> JsArray(obj.completedJobs.toList.map(_.toJson)) + ) + } + + implicit object WorkerStateJsonFormat extends RootJsonWriter[WorkerState] { + def write(obj: WorkerState) = JsObject( + "id" -> JsString(obj.workerId), + "masterurl" -> JsString(obj.masterUrl), + "masterwebuiurl" -> JsString(obj.masterWebUiUrl), + "cores" -> JsNumber(obj.cores), + "coresused" -> JsNumber(obj.coresUsed), + "memory" -> JsNumber(obj.memory), + "memoryused" -> JsNumber(obj.memoryUsed) + ) + } +} -- cgit v1.2.3 From 0da2ff102e1e8ac50059252a153a1b9b3e74b6b8 Mon Sep 17 00:00:00 2001 From: Tyson Date: Wed, 9 Jan 2013 10:36:56 -0500 Subject: Added url query parameter json and handler --- .../main/scala/spark/deploy/master/MasterWebUI.scala | 19 ++++++++++++++----- .../main/scala/spark/deploy/worker/WorkerWebUI.scala | 20 +++++++++++++++----- 2 files changed, 29 insertions(+), 10 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index 3cdd3721f5..dfec1d1dc5 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -9,6 +9,9 @@ import cc.spray.Directives import cc.spray.directives._ import cc.spray.typeconversion.TwirlSupport._ import spark.deploy._ +import cc.spray.http.MediaTypes +import JsonProtocol._ +import cc.spray.typeconversion.SprayJsonSupport._ private[spark] class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives { @@ -19,13 +22,19 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct val handler = { get { - path("") { - completeWith { + (path("") & parameters('json ?)) { + case Some(js) => val future = master ? RequestMasterState - future.map { - masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState]) + respondWithMediaType(MediaTypes.`application/json`) { ctx => + ctx.complete(future.mapTo[MasterState]) + } + case None => + completeWith { + val future = master ? RequestMasterState + future.map { + masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState]) + } } - } } ~ path("job") { parameter("jobId") { jobId => diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index d06f4884ee..a168f54ca0 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -7,7 +7,10 @@ import akka.util.Timeout import akka.util.duration._ import cc.spray.Directives import cc.spray.typeconversion.TwirlSupport._ -import spark.deploy.{WorkerState, RequestWorkerState} +import spark.deploy.{JsonProtocol, WorkerState, RequestWorkerState} +import cc.spray.http.MediaTypes +import JsonProtocol._ +import cc.spray.typeconversion.SprayJsonSupport._ private[spark] class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives { @@ -18,13 +21,20 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct val handler = { get { - path("") { - completeWith{ + (path("") & parameters('json ?)) { + case Some(js) => { val future = worker ? RequestWorkerState - future.map { workerState => - spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState]) + respondWithMediaType(MediaTypes.`application/json`) { ctx => + ctx.complete(future.mapTo[WorkerState]) } } + case None => + completeWith{ + val future = worker ? RequestWorkerState + future.map { workerState => + spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState]) + } + } } ~ path("log") { parameters("jobId", "executorId", "logType") { (jobId, executorId, logType) => -- cgit v1.2.3 From bf9d9946f97782c9212420123b4a042918d7df5e Mon Sep 17 00:00:00 2001 From: Tyson Date: Wed, 9 Jan 2013 11:29:22 -0500 Subject: Query parameter reformatted to be more extensible and routing more robust --- core/src/main/scala/spark/deploy/master/MasterWebUI.scala | 6 +++--- core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index dfec1d1dc5..a96b55d6f3 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -22,13 +22,13 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct val handler = { get { - (path("") & parameters('json ?)) { - case Some(js) => + (path("") & parameters('format ?)) { + case Some(js) if js.equalsIgnoreCase("json") => val future = master ? RequestMasterState respondWithMediaType(MediaTypes.`application/json`) { ctx => ctx.complete(future.mapTo[MasterState]) } - case None => + case _ => completeWith { val future = master ? RequestMasterState future.map { diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index a168f54ca0..84b6c16bd6 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -21,14 +21,14 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct val handler = { get { - (path("") & parameters('json ?)) { - case Some(js) => { + (path("") & parameters('format ?)) { + case Some(js) if js.equalsIgnoreCase("json") => { val future = worker ? RequestWorkerState respondWithMediaType(MediaTypes.`application/json`) { ctx => ctx.complete(future.mapTo[WorkerState]) } } - case None => + case _ => completeWith{ val future = worker ? RequestWorkerState future.map { workerState => -- cgit v1.2.3 From 549ee388a125ac7014ae3dadfb16c582e250c654 Mon Sep 17 00:00:00 2001 From: Tyson Date: Wed, 9 Jan 2013 15:12:23 -0500 Subject: Removed io.spray spray-json dependency as it is not needed. --- core/src/main/scala/spark/deploy/JsonProtocol.scala | 4 +--- project/SparkBuild.scala | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala index dc7da85f9c..f14f804b3a 100644 --- a/core/src/main/scala/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -1,14 +1,12 @@ package spark.deploy import master.{JobInfo, WorkerInfo} -import spray.json._ +import cc.spray.json._ /** * spray-json helper class containing implicit conversion to json for marshalling responses */ private[spark] object JsonProtocol extends DefaultJsonProtocol { - import cc.spray.json._ - implicit object WorkerInfoJsonFormat extends RootJsonWriter[WorkerInfo] { def write(obj: WorkerInfo) = JsObject( "id" -> JsString(obj.id), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f2b79d9ed8..c63efbdd2a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -134,7 +134,6 @@ object SparkBuild extends Build { "cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1", "cc.spray" %% "spray-json" % "1.1.1", - "io.spray" %% "spray-json" % "1.2.3", "org.apache.mesos" % "mesos" % "0.9.0-incubating" ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq, unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") } -- cgit v1.2.3 From e3861ae3953d7cab66160833688c8baf84e835ad Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Wed, 9 Jan 2013 17:03:25 -0600 Subject: Provide and expose a default Hadoop Configuration. Any "hadoop.*" system properties will be passed along into configuration. --- core/src/main/scala/spark/SparkContext.scala | 18 ++++++++++++++---- .../main/scala/spark/api/java/JavaSparkContext.scala | 7 +++++++ 2 files changed, 21 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bbf8272eb3..36e0938854 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -187,6 +187,18 @@ class SparkContext( private var dagScheduler = new DAGScheduler(taskScheduler) + /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ + val hadoopConfiguration = { + val conf = new Configuration() + // Copy any "hadoop.foo=bar" system properties into conf as "foo=bar" + for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("hadoop.")) { + conf.set(key.substring("hadoop.".length), System.getProperty(key)) + } + val bufferSize = System.getProperty("spark.buffer.size", "65536") + conf.set("io.file.buffer.size", bufferSize) + conf + } + // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. */ @@ -231,10 +243,8 @@ class SparkContext( valueClass: Class[V], minSplits: Int = defaultMinSplits ) : RDD[(K, V)] = { - val conf = new JobConf() + val conf = new JobConf(hadoopConfiguration) FileInputFormat.setInputPaths(conf, path) - val bufferSize = System.getProperty("spark.buffer.size", "65536") - conf.set("io.file.buffer.size", bufferSize) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) } @@ -276,7 +286,7 @@ class SparkContext( fm.erasure.asInstanceOf[Class[F]], km.erasure.asInstanceOf[Class[K]], vm.erasure.asInstanceOf[Class[V]], - new Configuration) + hadoopConfiguration) } /** diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index 88ab2846be..12e2a0bdac 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -355,6 +355,13 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def clearFiles() { sc.clearFiles() } + + /** + * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. + */ + def hadoopConfiguration() { + sc.hadoopConfiguration + } } object JavaSparkContext { -- cgit v1.2.3 From 9930a95d217045c4c22c2575080a03e4b0fd2426 Mon Sep 17 00:00:00 2001 From: shane-huang Date: Thu, 10 Jan 2013 20:09:34 +0800 Subject: Modified Patch according to comments --- core/src/main/scala/spark/network/Connection.scala | 8 ++++---- .../main/scala/spark/network/ConnectionManager.scala | 9 ++++----- .../scala/spark/network/ConnectionManagerTest.scala | 20 ++++++++++++++------ 3 files changed, 22 insertions(+), 15 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 95096fd0ba..c193bf7c8d 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -136,10 +136,10 @@ extends Connection(SocketChannel.open, selector_) { if (chunk.isDefined) { messages += message // this is probably incorrect, it wont work as fifo if (!message.started) { - logDebug("Starting to send [" + message + "]") - message.started = true - message.startTime = System.currentTimeMillis - } + logDebug("Starting to send [" + message + "]") + message.started = true + message.startTime = System.currentTimeMillis + } return chunk } else { /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index e7bd2d3bbd..36c01ad629 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -43,12 +43,12 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } val selector = SelectorProvider.provider.openSelector() - val handleMessageExecutor = Executors.newFixedThreadPool(20) + val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt) val serverChannel = ServerSocketChannel.open() val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] val messageStatuses = new HashMap[Int, MessageStatus] - val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] + val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] val sendMessageRequests = new Queue[(Message, SendingConnection)] @@ -78,9 +78,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def run() { try { - while(!selectorThread.isInterrupted) { + while(!selectorThread.isInterrupted) { for( (connectionManagerId, sendingConnection) <- connectionRequests) { - //val sendingConnection = connectionRequests.dequeue sendingConnection.connect() addConnection(sendingConnection) connectionRequests -= connectionManagerId @@ -465,7 +464,7 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 10 second) + val g = Await.result(f, 1 second) if (!g.isDefined) println("Failed") }) val finishTime = System.currentTimeMillis diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala index 0e79c518e0..533e4610f3 100644 --- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala @@ -13,8 +13,14 @@ import akka.util.duration._ private[spark] object ConnectionManagerTest extends Logging{ def main(args: Array[String]) { - if (args.length < 5) { - println("Usage: ConnectionManagerTest ") + // - the master URL + // - a list slaves to run connectionTest on + //[num of tasks] - the number of parallel tasks to be initiated default is number of slave hosts + //[size of msg in MB (integer)] - the size of messages to be sent in each task, default is 10 + //[count] - how many times to run, default is 3 + //[await time in seconds] : await time (in seconds), default is 600 + if (args.length < 2) { + println("Usage: ConnectionManagerTest [num of tasks] [size of msg in MB (integer)] [count] [await time in seconds)] ") System.exit(1) } @@ -29,14 +35,17 @@ private[spark] object ConnectionManagerTest extends Logging{ /*println("Slaves")*/ /*slaves.foreach(println)*/ - val tasknum = args(2).toInt + val tasknum = if (args.length > 2) args(2).toInt else slaves.length + val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024 + val count = if (args.length > 4) args(4).toInt else 3 + val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second + println("Running "+count+" rounds of test: " + "parallel tasks = " + tasknum + ", msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime) val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map( i => SparkEnv.get.connectionManager.id).collect() println("\nSlave ConnectionManagerIds") slaveConnManagerIds.foreach(println) println - val count = args(4).toInt (0 until count).foreach(i => { val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { val connManager = SparkEnv.get.connectionManager @@ -46,7 +55,6 @@ private[spark] object ConnectionManagerTest extends Logging{ None }) - val size = (args(3).toInt) * 1024 * 1024 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip @@ -56,7 +64,7 @@ private[spark] object ConnectionManagerTest extends Logging{ logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) }) - val results = futures.map(f => Await.result(f, 999.second)) + val results = futures.map(f => Await.result(f, awaitTime)) val finishTime = System.currentTimeMillis Thread.sleep(5000) -- cgit v1.2.3 From b15e8512793475eaeda7225a259db8aacd600741 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Thu, 10 Jan 2013 10:55:41 -0600 Subject: Check for AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY environment variables. For custom properties, use "spark.hadoop.*" as a prefix instead of just "hadoop.*". --- core/src/main/scala/spark/SparkContext.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 36e0938854..7b11955f1e 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -190,9 +190,16 @@ class SparkContext( /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = { val conf = new Configuration() - // Copy any "hadoop.foo=bar" system properties into conf as "foo=bar" - for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("hadoop.")) { - conf.set(key.substring("hadoop.".length), System.getProperty(key)) + // Explicitly check for S3 environment variables + if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { + conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + } + // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" + for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("spark.hadoop.")) { + conf.set(key.substring("spark.hadoop.".length), System.getProperty(key)) } val bufferSize = System.getProperty("spark.buffer.size", "65536") conf.set("io.file.buffer.size", bufferSize) -- cgit v1.2.3 From d1864052c58ff1e58980729f7ccf00e630f815b9 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Thu, 10 Jan 2013 12:16:26 -0600 Subject: Fix invalid asInstanceOf cast. --- core/src/main/scala/spark/SparkContext.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 7b11955f1e..d2a5b4757a 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -7,6 +7,7 @@ import java.net.{URI, URLClassLoader} import scala.collection.Map import scala.collection.generic.Growable import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.JavaConversions._ import akka.actor.Actor import akka.actor.Actor._ @@ -198,7 +199,7 @@ class SparkContext( conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) } // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" - for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("spark.hadoop.")) { + for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) { conf.set(key.substring("spark.hadoop.".length), System.getProperty(key)) } val bufferSize = System.getProperty("spark.buffer.size", "65536") -- cgit v1.2.3 From bd336f5f406386c929f2d1f9aecd7d5190a1a087 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 10 Jan 2013 17:13:04 -0800 Subject: Changed CoGroupRDD's hash map from Scala to Java. --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index de0d9fad88..2e051c81c8 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -1,7 +1,8 @@ package spark.rdd +import java.util.{HashMap => JHashMap} import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} @@ -71,7 +72,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { val split = s.asInstanceOf[CoGroupSplit] val numRdds = split.deps.size - val map = new HashMap[K, Seq[ArrayBuffer[Any]]] + val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] def getSeq(k: K): Seq[ArrayBuffer[Any]] = { map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any])) } -- cgit v1.2.3 From 2e914d99835487e867cac6add8be1dbd80dc693f Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 10 Jan 2013 19:13:08 -0800 Subject: Formatting --- core/src/main/scala/spark/deploy/master/MasterWebUI.scala | 5 +++-- core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index a96b55d6f3..580014ef3f 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -8,11 +8,12 @@ import akka.util.duration._ import cc.spray.Directives import cc.spray.directives._ import cc.spray.typeconversion.TwirlSupport._ -import spark.deploy._ import cc.spray.http.MediaTypes -import JsonProtocol._ import cc.spray.typeconversion.SprayJsonSupport._ +import spark.deploy._ +import spark.deploy.JsonProtocol._ + private[spark] class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives { val RESOURCE_DIR = "spark/deploy/master/webui" diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index 84b6c16bd6..f9489d99fc 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -7,11 +7,12 @@ import akka.util.Timeout import akka.util.duration._ import cc.spray.Directives import cc.spray.typeconversion.TwirlSupport._ -import spark.deploy.{JsonProtocol, WorkerState, RequestWorkerState} import cc.spray.http.MediaTypes -import JsonProtocol._ import cc.spray.typeconversion.SprayJsonSupport._ +import spark.deploy.{WorkerState, RequestWorkerState} +import spark.deploy.JsonProtocol._ + private[spark] class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives { val RESOURCE_DIR = "spark/deploy/worker/webui" -- cgit v1.2.3 From 92625223066a5c28553d7710c6b14af56f64b560 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 10 Jan 2013 22:07:34 -0800 Subject: Activate hadoop2 profile in pom.xml with -Dhadoop=2 --- bagel/pom.xml | 6 ++++++ core/pom.xml | 6 ++++++ examples/pom.xml | 6 ++++++ pom.xml | 6 ++++++ repl-bin/pom.xml | 6 ++++++ repl/pom.xml | 6 ++++++ 6 files changed, 36 insertions(+) (limited to 'core') diff --git a/bagel/pom.xml b/bagel/pom.xml index 85b2077026..c3461fb889 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -77,6 +77,12 @@ hadoop2 + + + hadoop + 2 + + org.spark-project diff --git a/core/pom.xml b/core/pom.xml index 005d8fe498..c8ff625774 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -216,6 +216,12 @@ hadoop2 + + + hadoop + 2 + + org.apache.hadoop diff --git a/examples/pom.xml b/examples/pom.xml index 3f738a3f8c..d0b1e97747 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -77,6 +77,12 @@ hadoop2 + + + hadoop + 2 + + org.spark-project diff --git a/pom.xml b/pom.xml index ea5b9c9d05..ae87813d4e 100644 --- a/pom.xml +++ b/pom.xml @@ -502,6 +502,12 @@ hadoop2 + + + hadoop + 2 + + 2 diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index fecb01f3cd..54ae20659e 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -115,6 +115,12 @@ hadoop2 + + + hadoop + 2 + + hadoop2 diff --git a/repl/pom.xml b/repl/pom.xml index 04b2c35beb..3e979b93a6 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -121,6 +121,12 @@ hadoop2 + + + hadoop + 2 + + hadoop2 -- cgit v1.2.3 From 3e6519a36e354f3623c5b968efe5217c7fcb242f Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 11 Jan 2013 11:24:20 -0600 Subject: Use hadoopConfiguration for default JobConf in PairRDDFunctions. --- core/src/main/scala/spark/PairRDDFunctions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index ce48cea903..51c15837c4 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -557,7 +557,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf = new JobConf) { + conf: JobConf = new JobConf(self.context.hadoopConfiguration)) { conf.setOutputKeyClass(keyClass) conf.setOutputValueClass(valueClass) // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug -- cgit v1.2.3 From 5c7a1272198c88a90a843bbda0c1424f92b7c12e Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 11 Jan 2013 11:25:11 -0600 Subject: Pass a new Configuration that wraps the default hadoopConfiguration. --- core/src/main/scala/spark/SparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d2a5b4757a..f6b98c41bc 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -294,7 +294,7 @@ class SparkContext( fm.erasure.asInstanceOf[Class[F]], km.erasure.asInstanceOf[Class[K]], vm.erasure.asInstanceOf[Class[V]], - hadoopConfiguration) + new Configuration(hadoopConfiguration)) } /** -- cgit v1.2.3 From c063e8777ebaeb04056889064e9264edc019edbd Mon Sep 17 00:00:00 2001 From: Tyson Date: Fri, 11 Jan 2013 14:57:38 -0500 Subject: Added implicit json writers for JobDescription and ExecutorRunner --- .../src/main/scala/spark/deploy/JsonProtocol.scala | 23 +++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala index f14f804b3a..732fa08064 100644 --- a/core/src/main/scala/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -1,6 +1,7 @@ package spark.deploy import master.{JobInfo, WorkerInfo} +import worker.ExecutorRunner import cc.spray.json._ /** @@ -30,6 +31,24 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { "submitdate" -> JsString(obj.submitDate.toString)) } + implicit object JobDescriptionJsonFormat extends RootJsonWriter[JobDescription] { + def write(obj: JobDescription) = JsObject( + "name" -> JsString(obj.name), + "cores" -> JsNumber(obj.cores), + "memoryperslave" -> JsNumber(obj.memoryPerSlave), + "user" -> JsString(obj.user) + ) + } + + implicit object ExecutorRunnerJsonFormat extends RootJsonWriter[ExecutorRunner] { + def write(obj: ExecutorRunner) = JsObject( + "id" -> JsNumber(obj.execId), + "memory" -> JsNumber(obj.memory), + "jobid" -> JsString(obj.jobId), + "jobdesc" -> obj.jobDesc.toJson.asJsObject + ) + } + implicit object MasterStateJsonFormat extends RootJsonWriter[MasterState] { def write(obj: MasterState) = JsObject( "url" -> JsString("spark://" + obj.uri), @@ -51,7 +70,9 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { "cores" -> JsNumber(obj.cores), "coresused" -> JsNumber(obj.coresUsed), "memory" -> JsNumber(obj.memory), - "memoryused" -> JsNumber(obj.memoryUsed) + "memoryused" -> JsNumber(obj.memoryUsed), + "executors" -> JsArray(obj.executors.toList.map(_.toJson)), + "finishedexecutors" -> JsArray(obj.finishedExecutors.toList.map(_.toJson)) ) } } -- cgit v1.2.3 From 1731f1fed4f1369662b1a9fde850a3dcba738a59 Mon Sep 17 00:00:00 2001 From: Tyson Date: Fri, 11 Jan 2013 15:01:43 -0500 Subject: Added an optional format parameter for individual job queries and optimized the jobId query --- .../scala/spark/deploy/master/MasterWebUI.scala | 38 +++++++++++++++------- 1 file changed, 27 insertions(+), 11 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index 580014ef3f..458ee2d665 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -38,20 +38,36 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct } } ~ path("job") { - parameter("jobId") { jobId => - completeWith { + parameters("jobId", 'format ?) { + case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) => val future = master ? RequestMasterState - future.map { state => - val masterState = state.asInstanceOf[MasterState] - - // A bit ugly an inefficient, but we won't have a number of jobs - // so large that it will make a significant difference. - (masterState.activeJobs ++ masterState.completedJobs).find(_.id == jobId) match { - case Some(job) => spark.deploy.master.html.job_details.render(job) - case _ => null + val jobInfo = for (masterState <- future.mapTo[MasterState]) yield { + masterState.activeJobs.find(_.id == jobId) match { + case Some(job) => job + case _ => masterState.completedJobs.find(_.id == jobId) match { + case Some(job) => job + case _ => null + } + } + } + respondWithMediaType(MediaTypes.`application/json`) { ctx => + ctx.complete(jobInfo.mapTo[JobInfo]) + } + case (jobId, _) => + completeWith { + val future = master ? RequestMasterState + future.map { state => + val masterState = state.asInstanceOf[MasterState] + + masterState.activeJobs.find(_.id == jobId) match { + case Some(job) => spark.deploy.master.html.job_details.render(job) + case _ => masterState.completedJobs.find(_.id == jobId) match { + case Some(job) => spark.deploy.master.html.job_details.render(job) + case _ => null + } + } } } - } } } ~ pathPrefix("static") { -- cgit v1.2.3 From 22445fbea9ed1575e49a1f9bb2251d98a57b9e4e Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Fri, 11 Jan 2013 13:30:49 -0800 Subject: attempt to sleep for more accurate time period, minor cleanup --- .../scala/spark/util/RateLimitedOutputStream.scala | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index d11ed163ce..3050213709 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -1,8 +1,10 @@ package spark.util import java.io.OutputStream +import java.util.concurrent.TimeUnit._ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { + val SyncIntervalNs = NANOSECONDS.convert(10, SECONDS) var lastSyncTime = System.nanoTime() var bytesWrittenSinceSync: Long = 0 @@ -28,20 +30,21 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu def waitToWrite(numBytes: Int) { while (true) { - val now = System.nanoTime() - val elapsed = math.max(now - lastSyncTime, 1) - val rate = bytesWrittenSinceSync.toDouble / (elapsed / 1.0e9) + val now = System.nanoTime + val elapsedSecs = SECONDS.convert(max(now - lastSyncTime, 1), NANOSECONDS) + val rate = bytesWrittenSinceSync.toDouble / elapsedSecs if (rate < bytesPerSec) { // It's okay to write; just update some variables and return bytesWrittenSinceSync += numBytes - if (now > lastSyncTime + (1e10).toLong) { - // Ten seconds have passed since lastSyncTime; let's resync + if (now > lastSyncTime + SyncIntervalNs) { + // Sync interval has passed; let's resync lastSyncTime = now bytesWrittenSinceSync = numBytes } - return } else { - Thread.sleep(5) + // Calculate how much time we should sleep to bring ourselves to the desired rate. + val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) + if (sleepTime > 0) Thread.sleep(sleepTime) } } } @@ -53,4 +56,4 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu override def close() { out.close() } -} \ No newline at end of file +} -- cgit v1.2.3 From ff10b3aa0970cc7224adc6bc73d99a7ffa30219f Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Fri, 11 Jan 2013 21:03:57 -0800 Subject: add missing return --- core/src/main/scala/spark/util/RateLimitedOutputStream.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index 3050213709..ed459c2544 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -41,6 +41,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu lastSyncTime = now bytesWrittenSinceSync = numBytes } + return } else { // Calculate how much time we should sleep to bring ourselves to the desired rate. val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) -- cgit v1.2.3 From 0cfea7a2ec467717fbe110f9b15163bea2719575 Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Fri, 11 Jan 2013 23:48:07 -0800 Subject: add unit test --- .../scala/spark/util/RateLimitedOutputStream.scala | 2 +- .../spark/util/RateLimitedOutputStreamSuite.scala | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala (limited to 'core') diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index ed459c2544..16db7549b2 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -31,7 +31,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu def waitToWrite(numBytes: Int) { while (true) { val now = System.nanoTime - val elapsedSecs = SECONDS.convert(max(now - lastSyncTime, 1), NANOSECONDS) + val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) val rate = bytesWrittenSinceSync.toDouble / elapsedSecs if (rate < bytesPerSec) { // It's okay to write; just update some variables and return diff --git a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala new file mode 100644 index 0000000000..1dc45e0433 --- /dev/null +++ b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala @@ -0,0 +1,22 @@ +package spark.util + +import org.scalatest.FunSuite +import java.io.ByteArrayOutputStream +import java.util.concurrent.TimeUnit._ + +class RateLimitedOutputStreamSuite extends FunSuite { + + private def benchmark[U](f: => U): Long = { + val start = System.nanoTime + f + System.nanoTime - start + } + + test("write") { + val underlying = new ByteArrayOutputStream + val data = "X" * 1000 + val stream = new RateLimitedOutputStream(underlying, 100) + val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } + assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4) + } +} -- cgit v1.2.3 From 2c77eeebb66a3d1337d45b5001be2b48724f9fd5 Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Sat, 12 Jan 2013 00:13:45 -0800 Subject: correct test params --- core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala index 1dc45e0433..b392075482 100644 --- a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala +++ b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala @@ -14,8 +14,8 @@ class RateLimitedOutputStreamSuite extends FunSuite { test("write") { val underlying = new ByteArrayOutputStream - val data = "X" * 1000 - val stream = new RateLimitedOutputStream(underlying, 100) + val data = "X" * 41000 + val stream = new RateLimitedOutputStream(underlying, 10000) val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4) } -- cgit v1.2.3 From ea20ae661888d871f70d5ed322cfe924c5a31dba Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Sat, 12 Jan 2013 09:18:00 -0800 Subject: add one extra test --- core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala index b392075482..794063fb6d 100644 --- a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala +++ b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala @@ -18,5 +18,6 @@ class RateLimitedOutputStreamSuite extends FunSuite { val stream = new RateLimitedOutputStream(underlying, 10000) val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4) + assert(underlying.toString("UTF-8") == data) } } -- cgit v1.2.3 From addff2c466d4b76043e612d4d28ab9de7f003298 Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Sat, 12 Jan 2013 09:57:29 -0800 Subject: add comment --- core/src/main/scala/spark/util/RateLimitedOutputStream.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index 16db7549b2..ed3d2b66bb 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -44,6 +44,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu return } else { // Calculate how much time we should sleep to bring ourselves to the desired rate. + // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) if (sleepTime > 0) Thread.sleep(sleepTime) } -- cgit v1.2.3 From bbc56d85ed4eb4c3a09b20d5457f704f4b8a70c4 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sat, 12 Jan 2013 15:24:13 -0800 Subject: Rename environment variable for hadoop profiles to hadoopVersion --- bagel/pom.xml | 4 ++-- core/pom.xml | 4 ++-- examples/pom.xml | 4 ++-- pom.xml | 5 +++-- repl-bin/pom.xml | 4 ++-- repl/pom.xml | 4 ++-- 6 files changed, 13 insertions(+), 12 deletions(-) (limited to 'core') diff --git a/bagel/pom.xml b/bagel/pom.xml index c3461fb889..5f58347204 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -47,7 +47,7 @@ hadoop1 - !hadoop + !hadoopVersion @@ -79,7 +79,7 @@ hadoop2 - hadoop + hadoopVersion 2 diff --git a/core/pom.xml b/core/pom.xml index c8ff625774..ad9fdcde2c 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -161,7 +161,7 @@ hadoop1 - !hadoop + !hadoopVersion @@ -218,7 +218,7 @@ hadoop2 - hadoop + hadoopVersion 2 diff --git a/examples/pom.xml b/examples/pom.xml index d0b1e97747..3355deb6b7 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -47,7 +47,7 @@ hadoop1 - !hadoop + !hadoopVersion @@ -79,7 +79,7 @@ hadoop2 - hadoop + hadoopVersion 2 diff --git a/pom.xml b/pom.xml index ae87813d4e..8f1af673a3 100644 --- a/pom.xml +++ b/pom.xml @@ -483,9 +483,10 @@ hadoop1 - !hadoop + !hadoopVersion + 1 @@ -504,7 +505,7 @@ hadoop2 - hadoop + hadoopVersion 2 diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index 54ae20659e..da91c0f3ab 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -72,7 +72,7 @@ hadoop1 - !hadoop + !hadoopVersion @@ -117,7 +117,7 @@ hadoop2 - hadoop + hadoopVersion 2 diff --git a/repl/pom.xml b/repl/pom.xml index 3e979b93a6..38e883c7f8 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -74,7 +74,7 @@ hadoop1 - !hadoop + !hadoopVersion @@ -123,7 +123,7 @@ hadoop2 - hadoop + hadoopVersion 2 -- cgit v1.2.3 From 88d8f11365db84d46ff456495c07f664c91d1896 Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Sun, 13 Jan 2013 00:45:52 -0800 Subject: Add missing dependency spray-json to Maven build --- core/pom.xml | 4 ++++ pom.xml | 6 ++++++ 2 files changed, 10 insertions(+) (limited to 'core') diff --git a/core/pom.xml b/core/pom.xml index ad9fdcde2c..862d3ec37a 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -71,6 +71,10 @@ cc.spray spray-server + + cc.spray + spray-json_${scala.version} + org.tomdz.twirl twirl-api diff --git a/pom.xml b/pom.xml index 8f1af673a3..751189a9d8 100644 --- a/pom.xml +++ b/pom.xml @@ -54,6 +54,7 @@ 0.9.0-incubating 2.0.3 1.0-M2.1 + 1.1.1 1.6.1 4.1.2 @@ -222,6 +223,11 @@ spray-server ${spray.version} + + cc.spray + spray-json_${scala.version} + ${spray.json.version} + org.tomdz.twirl twirl-api -- cgit v1.2.3 From 2305a2c1d91273a93ee6b571b0cd4bcaa1b2969d Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Sun, 13 Jan 2013 10:01:56 -0800 Subject: more code cleanup --- .../scala/spark/util/RateLimitedOutputStream.scala | 63 +++++++++++----------- 1 file changed, 32 insertions(+), 31 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index ed3d2b66bb..10790a9eee 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -1,11 +1,14 @@ package spark.util +import scala.annotation.tailrec + import java.io.OutputStream import java.util.concurrent.TimeUnit._ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { val SyncIntervalNs = NANOSECONDS.convert(10, SECONDS) - var lastSyncTime = System.nanoTime() + val ChunkSize = 8192 + var lastSyncTime = System.nanoTime var bytesWrittenSinceSync: Long = 0 override def write(b: Int) { @@ -17,37 +20,13 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu write(bytes, 0, bytes.length) } - override def write(bytes: Array[Byte], offset: Int, length: Int) { - val CHUNK_SIZE = 8192 - var pos = 0 - while (pos < length) { - val writeSize = math.min(length - pos, CHUNK_SIZE) + @tailrec + override final def write(bytes: Array[Byte], offset: Int, length: Int) { + val writeSize = math.min(length - offset, ChunkSize) + if (writeSize > 0) { waitToWrite(writeSize) - out.write(bytes, offset + pos, writeSize) - pos += writeSize - } - } - - def waitToWrite(numBytes: Int) { - while (true) { - val now = System.nanoTime - val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) - val rate = bytesWrittenSinceSync.toDouble / elapsedSecs - if (rate < bytesPerSec) { - // It's okay to write; just update some variables and return - bytesWrittenSinceSync += numBytes - if (now > lastSyncTime + SyncIntervalNs) { - // Sync interval has passed; let's resync - lastSyncTime = now - bytesWrittenSinceSync = numBytes - } - return - } else { - // Calculate how much time we should sleep to bring ourselves to the desired rate. - // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) - val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) - if (sleepTime > 0) Thread.sleep(sleepTime) - } + out.write(bytes, offset, writeSize) + write(bytes, offset + writeSize, length) } } @@ -58,4 +37,26 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu override def close() { out.close() } + + @tailrec + private def waitToWrite(numBytes: Int) { + val now = System.nanoTime + val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) + val rate = bytesWrittenSinceSync.toDouble / elapsedSecs + if (rate < bytesPerSec) { + // It's okay to write; just update some variables and return + bytesWrittenSinceSync += numBytes + if (now > lastSyncTime + SyncIntervalNs) { + // Sync interval has passed; let's resync + lastSyncTime = now + bytesWrittenSinceSync = numBytes + } + } else { + // Calculate how much time we should sleep to bring ourselves to the desired rate. + // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) + val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) + if (sleepTime > 0) Thread.sleep(sleepTime) + waitToWrite(numBytes) + } + } } -- cgit v1.2.3 From c31931af7eb01fbe2bb276bb6f428248128832b0 Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Sun, 13 Jan 2013 10:39:47 -0800 Subject: switch to uppercase constants --- core/src/main/scala/spark/util/RateLimitedOutputStream.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index 10790a9eee..e3f00ea8c7 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -6,8 +6,8 @@ import java.io.OutputStream import java.util.concurrent.TimeUnit._ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { - val SyncIntervalNs = NANOSECONDS.convert(10, SECONDS) - val ChunkSize = 8192 + val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) + val CHUNK_SIZE = 8192 var lastSyncTime = System.nanoTime var bytesWrittenSinceSync: Long = 0 @@ -22,7 +22,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu @tailrec override final def write(bytes: Array[Byte], offset: Int, length: Int) { - val writeSize = math.min(length - offset, ChunkSize) + val writeSize = math.min(length - offset, CHUNK_SIZE) if (writeSize > 0) { waitToWrite(writeSize) out.write(bytes, offset, writeSize) @@ -46,7 +46,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu if (rate < bytesPerSec) { // It's okay to write; just update some variables and return bytesWrittenSinceSync += numBytes - if (now > lastSyncTime + SyncIntervalNs) { + if (now > lastSyncTime + SYNC_INTERVAL) { // Sync interval has passed; let's resync lastSyncTime = now bytesWrittenSinceSync = numBytes -- cgit v1.2.3 From be7166146bf5692369272b85622d5316eccfd8e6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 13 Jan 2013 15:27:28 -0800 Subject: Removed the use of getOrElse to avoid Scala wrapper for every call. --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 2e051c81c8..ce5f171911 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -1,8 +1,8 @@ package spark.rdd import java.util.{HashMap => JHashMap} +import scala.collection.JavaConversions import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} @@ -74,7 +74,14 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) val numRdds = split.deps.size val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] def getSeq(k: K): Seq[ArrayBuffer[Any]] = { - map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any])) + val seq = map.get(k) + if (seq != null) { + seq + } else { + val seq = Array.fill(numRdds)(new ArrayBuffer[Any]) + map.put(k, seq) + seq + } } for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, itsSplit) => { @@ -94,6 +101,6 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair) } } - map.iterator + JavaConversions.mapAsScalaMap(map).iterator } } -- cgit v1.2.3 From 72408e8dfacc24652f376d1ee4dd6f04edb54804 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 13 Jan 2013 19:34:07 -0800 Subject: Make filter preserve partitioner info, since it can --- core/src/main/scala/spark/rdd/FilteredRDD.scala | 3 ++- core/src/test/scala/spark/PartitioningSuite.scala | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index b148da28de..d46549b8b6 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -7,5 +7,6 @@ private[spark] class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) { override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) + override val partitioner = prev.partitioner // Since filter cannot change a partition's keys override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).filter(f) -} \ No newline at end of file +} diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index f09b602a7b..eb3c8f238f 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -106,6 +106,11 @@ class PartitioningSuite extends FunSuite with BeforeAndAfter { assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner) + + assert(grouped2.map(_ => 1).partitioner === None) + assert(grouped2.mapValues(_ => 1).partitioner === grouped2.partitioner) + assert(grouped2.flatMapValues(_ => Seq(1)).partitioner === grouped2.partitioner) + assert(grouped2.filter(_._1 > 4).partitioner === grouped2.partitioner) } test("partitioning Java arrays should fail") { -- cgit v1.2.3 From 0dbd411a562396e024c513936fde46b0d2f6d59d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 13 Jan 2013 21:08:35 -0800 Subject: Added documentation for PairDStreamFunctions. --- core/src/main/scala/spark/PairRDDFunctions.scala | 6 +- docs/streaming-programming-guide.md | 45 ++-- .../src/main/scala/spark/streaming/DStream.scala | 35 ++- .../spark/streaming/PairDStreamFunctions.scala | 293 ++++++++++++++++++++- .../scala/spark/streaming/util/RawTextHelper.scala | 2 +- 5 files changed, 331 insertions(+), 50 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 07ae2d647c..d95b66ad78 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -199,9 +199,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. */ def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { this.cogroup(other, partitioner).flatMapValues { diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 05a88ce7bd..b6da7af654 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -43,7 +43,7 @@ A complete list of input sources is available in the [StreamingContext API docum # DStream Operations -Once an input stream has been created, you can transform it using _stream operators_. Most of these operators return new DStreams which you can further transform. Eventually, you'll need to call an _output operator_, which forces evaluation of the stream by writing data out to an external source. +Once an input DStream has been created, you can transform it using _DStream operators_. Most of these operators return new DStreams which you can further transform. Eventually, you'll need to call an _output operator_, which forces evaluation of the DStream by writing data out to an external source. ## Transformations @@ -53,11 +53,11 @@ DStreams support many of the transformations available on normal Spark RDD's:
TransformationMeaning
map(func) Return a new stream formed by passing each element of the source through a function func. Returns a new DStream formed by passing each element of the source through a function func.
filter(func) Return a new stream formed by selecting those elements of the source on which func returns true. Returns a new stream formed by selecting those elements of the source on which func returns true.
flatMap(func)
cogroup(otherStream, [numTasks]) When called on streams of type (K, V) and (K, W), returns a stream of (K, Seq[V], Seq[W]) tuples. This operation is also called groupWith. When called on DStream of type (K, V) and (K, W), returns a DStream of (K, Seq[V], Seq[W]) tuples.
reduce(func) Create a new single-element stream by aggregating the elements of the stream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel. Returns a new DStream of single-element RDDs by aggregating the elements of the stream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel.
transform(func) Returns a new DStream by applying func (a RDD-to-RDD function) to every RDD of the stream. This can be used to do arbitrary RDD operations on the DStream.
-Spark Streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a windowTime, which represents the width of the window and a slideTime, which represents the frequency during which the window is calculated. +Spark Streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a windowDuration, which represents the width of the window and a slideTime, which represents the frequency during which the window is calculated. - - + - - + - - + -
TransformationMeaning
window(windowTime, slideTime) Return a new stream which is computed based on windowed batches of the source stream. windowTime is the width of the window and slideTime is the frequency during which the window is calculated. Both times must be multiples of the batch interval. + window(windowDuration, slideTime) Return a new stream which is computed based on windowed batches of the source stream. windowDuration is the width of the window and slideTime is the frequency during which the window is calculated. Both times must be multiples of the batch interval.
countByWindow(windowTime, slideTime) Return a sliding count of elements in the stream. windowTime and slideTime are exactly as defined in window(). + countByWindow(windowDuration, slideTime) Return a sliding count of elements in the stream. windowDuration and slideDuration are exactly as defined in window().
reduceByWindow(func, windowTime, slideTime) Return a new single-element stream, created by aggregating elements in the stream over a sliding interval using func. The function should be associative so that it can be computed correctly in parallel. windowTime and slideTime are exactly as defined in window(). + reduceByWindow(func, windowDuration, slideDuration) Return a new single-element stream, created by aggregating elements in the stream over a sliding interval using func. The function should be associative so that it can be computed correctly in parallel. windowDuration and slideDuration are exactly as defined in window().
groupByKeyAndWindow(windowTime, slideTime, [numTasks]) + groupByKeyAndWindow(windowDuration, slideDuration, [numTasks]) When called on a stream of (K, V) pairs, returns a stream of (K, Seq[V]) pairs over a sliding window.
-Note: By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional numTasks argument to set a different number of tasks. windowTime and slideTime are exactly as defined in window(). +Note: By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional numTasks argument to set a different number of tasks. windowDuration and slideDuration are exactly as defined in window().
reduceByKeyAndWindow(func, [numTasks]) When called on a stream of (K, V) pairs, returns a stream of (K, V) pairs where the values for each key are aggregated using the given reduce function over batches within a sliding window. Like in groupByKeyAndWindow, the number of reduce tasks is configurable through an optional second argument. - windowTime and slideTime are exactly as defined in window(). + windowDuration and slideDuration are exactly as defined in window().
countByKeyAndWindow([numTasks]) When called on a stream of (K, V) pairs, returns a stream of (K, Int) pairs where the values for each key are the count within a sliding window. Like in countByKeyAndWindow, the number of reduce tasks is configurable through an optional second argument. - windowTime and slideTime are exactly as defined in window(). + windowDuration and slideDuration are exactly as defined in window().
+A complete list of DStream operations is available in the API documentation of [DStream](api/streaming/index.html#spark.streaming.DStream) and [PairDStreamFunctions](api/streaming/index.html#spark.streaming.PairDStreamFunctions). ## Output Operations When an output operator is called, it triggers the computation of a stream. Currently the following output operators are defined: @@ -144,7 +149,7 @@ When an output operator is called, it triggers the computation of a stream. Curr - + @@ -155,18 +160,18 @@ When an output operator is called, it triggers the computation of a stream. Curr - - + - +
OperatorMeaning
foreachRDD(func) foreach(func) The fundamental output operator. Applies a function, func, to each RDD generated from the stream. This function should have side effects, such as printing output, saving the RDD to external files, or writing it over the network to an external system.
saveAsObjectFiles(prefix, [suffix]) Save this DStream's contents as a SequenceFile of serialized objects. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]". + Save this DStream's contents as a SequenceFile of serialized objects. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
saveAsTextFiles(prefix, [suffix]) Save this DStream's contents as a text files. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]". Save this DStream's contents as a text files. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
saveAsHadoopFiles(prefix, [suffix]) Save this DStream's contents as a Hadoop file. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]". Save this DStream's contents as a Hadoop file. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index c89fb7723e..d94548a4f3 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -471,7 +471,7 @@ abstract class DStream[T: ClassManifest] ( * Returns a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): DStream[Int] = this.map(_ => 1).reduce(_ + _) + def count(): DStream[Long] = this.map(_ => 1L).reduce(_ + _) /** * Applies a function to each RDD in this DStream. This is an output operator, so @@ -529,17 +529,16 @@ abstract class DStream[T: ClassManifest] ( * Return a new DStream which is computed based on windowed batches of this DStream. * The new DStream generates RDDs with the same interval as this DStream. * @param windowDuration width of the window; must be a multiple of this DStream's interval. - * @return */ def window(windowDuration: Duration): DStream[T] = window(windowDuration, this.slideDuration) /** * Return a new DStream which is computed based on windowed batches of this DStream. - * @param windowDuration duration (i.e., width) of the window; - * must be a multiple of this DStream's interval + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's interval + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval */ def window(windowDuration: Duration, slideDuration: Duration): DStream[T] = { new WindowedDStream(this, windowDuration, slideDuration) @@ -548,16 +547,22 @@ abstract class DStream[T: ClassManifest] ( /** * Returns a new DStream which computed based on tumbling window on this DStream. * This is equivalent to window(batchTime, batchTime). - * @param batchTime tumbling window duration; must be a multiple of this DStream's interval + * @param batchTime tumbling window duration; must be a multiple of this DStream's + * batching interval */ def tumble(batchTime: Duration): DStream[T] = window(batchTime, batchTime) /** * Returns a new DStream in which each RDD has a single element generated by reducing all - * elements in a window over this DStream. windowDuration and slideDuration are as defined in the - * window() operation. This is equivalent to window(windowDuration, slideDuration).reduce(reduceFunc) + * elements in a window over this DStream. windowDuration and slideDuration are as defined + * in the window() operation. This is equivalent to + * window(windowDuration, slideDuration).reduce(reduceFunc) */ - def reduceByWindow(reduceFunc: (T, T) => T, windowDuration: Duration, slideDuration: Duration): DStream[T] = { + def reduceByWindow( + reduceFunc: (T, T) => T, + windowDuration: Duration, + slideDuration: Duration + ): DStream[T] = { this.window(windowDuration, slideDuration).reduce(reduceFunc) } @@ -577,8 +582,8 @@ abstract class DStream[T: ClassManifest] ( * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the * window() operation. This is equivalent to window(windowDuration, slideDuration).count() */ - def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Int] = { - this.map(_ => 1).reduceByWindow(_ + _, _ - _, windowDuration, slideDuration) + def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Long] = { + this.map(_ => 1L).reduceByWindow(_ + _, _ - _, windowDuration, slideDuration) } /** @@ -612,6 +617,8 @@ abstract class DStream[T: ClassManifest] ( /** * Saves each RDD in this DStream as a Sequence file of serialized objects. + * The file name at each batch interval is generated based on `prefix` and + * `suffix`: "prefix-TIME_IN_MS.suffix". */ def saveAsObjectFiles(prefix: String, suffix: String = "") { val saveFunc = (rdd: RDD[T], time: Time) => { @@ -622,7 +629,9 @@ abstract class DStream[T: ClassManifest] ( } /** - * Saves each RDD in this DStream as at text file, using string representation of elements. + * Saves each RDD in this DStream as at text file, using string representation + * of elements. The file name at each batch interval is generated based on + * `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ def saveAsTextFiles(prefix: String, suffix: String = "") { val saveFunc = (rdd: RDD[T], time: Time) => { diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 482d01300d..3952457339 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -25,34 +25,76 @@ extends Serializable { new HashPartitioner(numPartitions) } + /** + * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs are grouped into a + * single sequence to generate the RDDs of the new DStream. Hash partitioning is + * used to generate the RDDs with Spark's default number of partitions. + */ def groupByKey(): DStream[(K, Seq[V])] = { groupByKey(defaultPartitioner()) } + /** + * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs are grouped into a + * single sequence to generate the RDDs of the new DStream. Hash partitioning is + * used to generate the RDDs with `numPartitions` partitions. + */ def groupByKey(numPartitions: Int): DStream[(K, Seq[V])] = { groupByKey(defaultPartitioner(numPartitions)) } + /** + * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs are grouped into a + * single sequence to generate the RDDs of the new DStream. [[spark.Partitioner]] + * is used to control the partitioning of each RDD. + */ def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = { val createCombiner = (v: V) => ArrayBuffer[V](v) val mergeValue = (c: ArrayBuffer[V], v: V) => (c += v) val mergeCombiner = (c1: ArrayBuffer[V], c2: ArrayBuffer[V]) => (c1 ++ c2) - combineByKey(createCombiner, mergeValue, mergeCombiner, partitioner).asInstanceOf[DStream[(K, Seq[V])]] + combineByKey(createCombiner, mergeValue, mergeCombiner, partitioner) + .asInstanceOf[DStream[(K, Seq[V])]] } + /** + * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs is merged using the + * associative reduce function to generate the RDDs of the new DStream. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + */ def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = { reduceByKey(reduceFunc, defaultPartitioner()) } + /** + * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs is merged using the + * associative reduce function to generate the RDDs of the new DStream. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + */ def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int): DStream[(K, V)] = { reduceByKey(reduceFunc, defaultPartitioner(numPartitions)) } + /** + * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs is merged using the + * associative reduce function to generate the RDDs of the new DStream. + * [[spark.Partitioner]] is used to control the partitioning of each RDD. + */ def reduceByKey(reduceFunc: (V, V) => V, partitioner: Partitioner): DStream[(K, V)] = { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) combineByKey((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, partitioner) } + /** + * Generic function to combine elements of each key in DStream's RDDs using custom function. + * This is similar to the combineByKey for RDDs. Please refer to combineByKey in + * [[spark.PairRDDFunctions]] for more information. + */ def combineByKey[C: ClassManifest]( createCombiner: V => C, mergeValue: (C, V) => C, @@ -61,14 +103,52 @@ extends Serializable { new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner) } + /** + * Creates a new DStream by counting the number of values of each key in each RDD + * of `this` DStream. Hash partitioning is used to generate the RDDs with Spark's + * `numPartitions` partitions. + */ def countByKey(numPartitions: Int = self.ssc.sc.defaultParallelism): DStream[(K, Long)] = { self.map(x => (x._1, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) } + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * The new DStream generates RDDs with the same interval as this DStream. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + */ + def groupByKeyAndWindow(windowDuration: Duration): DStream[(K, Seq[V])] = { + groupByKeyAndWindow(windowDuration, self.slideDuration, defaultPartitioner()) + } + + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration): DStream[(K, Seq[V])] = { groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner()) } + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def groupByKeyAndWindow( windowDuration: Duration, slideDuration: Duration, @@ -77,6 +157,16 @@ extends Serializable { groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner(numPartitions)) } + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + */ def groupByKeyAndWindow( windowDuration: Duration, slideDuration: Duration, @@ -85,6 +175,15 @@ extends Serializable { self.window(windowDuration, slideDuration).groupByKey(partitioner) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * The new DStream generates RDDs with the same interval as this DStream. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowDuration: Duration @@ -92,6 +191,17 @@ extends Serializable { reduceByKeyAndWindow(reduceFunc, windowDuration, self.slideDuration, defaultPartitioner()) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowDuration: Duration, @@ -100,6 +210,18 @@ extends Serializable { reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner()) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowDuration: Duration, @@ -109,6 +231,17 @@ extends Serializable { reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions)) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowDuration: Duration, @@ -121,12 +254,23 @@ extends Serializable { .reduceByKey(cleanedReduceFunc, partitioner) } - // This method is the efficient sliding window reduce operation, - // which requires the specification of an inverse reduce function, - // so that new elements introduced in the window can be "added" using - // reduceFunc to the previous window's result and old elements can be - // "subtracted using invReduceFunc. - + /** + * Creates a new DStream by reducing over a window in a smarter way. + * The reduced value of over a new window is calculated incrementally by using the + * old window's reduce value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, @@ -138,6 +282,24 @@ extends Serializable { reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner()) } + /** + * Creates a new DStream by reducing over a window in a smarter way. + * The reduced value of over a new window is calculated incrementally by using the + * old window's reduce value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, @@ -150,6 +312,23 @@ extends Serializable { reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions)) } + /** + * Creates a new DStream by reducing over a window in a smarter way. + * The reduced value of over a new window is calculated incrementally by using the + * old window's reduce value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, @@ -164,6 +343,16 @@ extends Serializable { self, cleanedReduceFunc, cleanedInvReduceFunc, windowDuration, slideDuration, partitioner) } + /** + * Creates a new DStream by counting the number of values for each key over a window. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def countByKeyAndWindow( windowDuration: Duration, slideDuration: Duration, @@ -179,17 +368,30 @@ extends Serializable { ) } - // TODO: - // - // - // - // + /** + * Creates a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key from + * `this` DStream. Hash partitioning is used to generate the RDDs with Spark's default + * number of partitions. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @tparam S State type + */ def updateStateByKey[S <: AnyRef : ClassManifest]( updateFunc: (Seq[V], Option[S]) => Option[S] ): DStream[(K, S)] = { updateStateByKey(updateFunc, defaultPartitioner()) } + /** + * Creates a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key from + * `this` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @param numPartitions Number of partitions of each RDD in the new DStream. + * @tparam S State type + */ def updateStateByKey[S <: AnyRef : ClassManifest]( updateFunc: (Seq[V], Option[S]) => Option[S], numPartitions: Int @@ -197,6 +399,15 @@ extends Serializable { updateStateByKey(updateFunc, defaultPartitioner(numPartitions)) } + /** + * Creates a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key from + * `this` DStream. [[spark.Partitioner]] is used to control the partitioning of each RDD. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + * @tparam S State type + */ def updateStateByKey[S <: AnyRef : ClassManifest]( updateFunc: (Seq[V], Option[S]) => Option[S], partitioner: Partitioner @@ -207,6 +418,19 @@ extends Serializable { updateStateByKey(newUpdateFunc, partitioner, true) } + /** + * Creates a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key from + * `this` DStream. [[spark.Partitioner]] is used to control the partitioning of each RDD. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. Note, that + * this function may generate a different a tuple with a different key + * than the input key. It is up to the developer to decide whether to + * remember the partitioner despite the key being changed. + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + * @param rememberPartitioner Whether to remember the paritioner object in the generated RDDs. + * @tparam S State type + */ def updateStateByKey[S <: AnyRef : ClassManifest]( updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, @@ -226,10 +450,24 @@ extends Serializable { new FlatMapValuedDStream[K, V, U](self, flatMapValuesFunc) } + /** + * Cogroups `this` DStream with `other` DStream. Each RDD of the new DStream will + * be generated by cogrouping RDDs from`this`and `other` DStreams. Therefore, for + * each key k in corresponding RDDs of `this` or `other` DStreams, the generated RDD + * will contains a tuple with the list of values for that key in both RDDs. + * HashPartitioner is used to partition each generated RDD into default number of partitions. + */ def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = { cogroup(other, defaultPartitioner()) } + /** + * Cogroups `this` DStream with `other` DStream. Each RDD of the new DStream will + * be generated by cogrouping RDDs from`this`and `other` DStreams. Therefore, for + * each key k in corresponding RDDs of `this` or `other` DStreams, the generated RDD + * will contains a tuple with the list of values for that key in both RDDs. + * Partitioner is used to partition each generated RDD. + */ def cogroup[W: ClassManifest]( other: DStream[(K, W)], partitioner: Partitioner @@ -249,11 +487,24 @@ extends Serializable { } } + /** + * Joins `this` DStream with `other` DStream. Each RDD of the new DStream will + * be generated by joining RDDs from `this` and `other` DStreams. HashPartitioner is used + * to partition each generated RDD into default number of partitions. + */ def join[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (V, W))] = { join[W](other, defaultPartitioner()) } - def join[W: ClassManifest](other: DStream[(K, W)], partitioner: Partitioner): DStream[(K, (V, W))] = { + /** + * Joins `this` DStream with `other` DStream, that is, each RDD of the new DStream will + * be generated by joining RDDs from `this` and other DStream. Uses the given + * Partitioner to partition each generated RDD. + */ + def join[W: ClassManifest]( + other: DStream[(K, W)], + partitioner: Partitioner + ): DStream[(K, (V, W))] = { this.cogroup(other, partitioner) .flatMapValues{ case (vs, ws) => @@ -261,6 +512,10 @@ extends Serializable { } } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" + */ def saveAsHadoopFiles[F <: OutputFormat[K, V]]( prefix: String, suffix: String @@ -268,6 +523,10 @@ extends Serializable { saveAsHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" + */ def saveAsHadoopFiles( prefix: String, suffix: String, @@ -283,6 +542,10 @@ extends Serializable { self.foreach(saveFunc) } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]]( prefix: String, suffix: String @@ -290,6 +553,10 @@ extends Serializable { saveAsNewAPIHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ def saveAsNewAPIHadoopFiles( prefix: String, suffix: String, diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala index f31ae39a16..03749d4a94 100644 --- a/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala @@ -81,7 +81,7 @@ object RawTextHelper { * before real workload starts. */ def warmUp(sc: SparkContext) { - for(i <- 0 to 4) { + for(i <- 0 to 1) { sc.parallelize(1 to 200000, 1000) .map(_ % 1331).map(_.toString) .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) -- cgit v1.2.3 From 131be5d62ef6b770de5106eb268a45bca385b599 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 14 Jan 2013 03:28:25 -0800 Subject: Fixed bug in RDD checkpointing. --- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 86c63ca2f4..6f00f6ac73 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -80,12 +80,12 @@ private[spark] object CheckpointRDD extends Logging { val serializer = SparkEnv.get.serializer.newInstance() val serializeStream = serializer.serializeStream(fileOutputStream) serializeStream.writeAll(iterator) - fileOutputStream.close() + serializeStream.close() if (!fs.rename(tempOutputPath, finalOutputPath)) { if (!fs.delete(finalOutputPath, true)) { throw new IOException("Checkpoint failed: failed to delete earlier output of task " - + context.attemptId); + + context.attemptId) } if (!fs.rename(tempOutputPath, finalOutputPath)) { throw new IOException("Checkpoint failed: failed to save output of task: " @@ -119,7 +119,7 @@ private[spark] object CheckpointRDD extends Logging { val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) val path = new Path(hdfsPath, "temp") val fs = path.getFileSystem(new Configuration()) - sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 10) _) + sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same") assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same") -- cgit v1.2.3 From 273fb5cc109ac0a032f84c1566ae908cd0eb27b6 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Thu, 3 Jan 2013 14:09:56 -0800 Subject: Throw FetchFailedException for cached missing locs --- core/src/main/scala/spark/MapOutputTracker.scala | 36 +++++++++++++++++------- 1 file changed, 26 insertions(+), 10 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 70eb9f702e..9f2aa76830 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -139,8 +139,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea case e: InterruptedException => } } - return mapStatuses.get(shuffleId).map(status => - (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, + mapStatuses.get(shuffleId)) } else { fetching += shuffleId } @@ -156,21 +156,15 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea fetchedStatuses = deserializeStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) - if (fetchedStatuses.contains(null)) { - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing an output location for shuffle " + shuffleId)) - } } finally { fetching.synchronized { fetching -= shuffleId fetching.notifyAll() } } - return fetchedStatuses.map(s => - (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } else { - return statuses.map(s => - (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) } } @@ -258,6 +252,28 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea private[spark] object MapOutputTracker { private val LOG_BASE = 1.1 + // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If + // any of the statuses is null (indicating a missing location due to a failed mapper), + // throw a FetchFailedException. + def convertMapStatuses( + shuffleId: Int, + reduceId: Int, + statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { + if (statuses == null) { + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing all output locations for shuffle " + shuffleId)) + } + statuses.map { + status => + if (status == null) { + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing an output location for shuffle " + shuffleId)) + } else { + (status.address, decompressSize(status.compressedSizes(reduceId))) + } + } + } + /** * Compress a size in bytes to 8 bits for efficient reporting of map output sizes. * We do this by encoding the log base 1.1 of the size as an integer, which can support -- cgit v1.2.3 From 7ba34bc007ec10d12b2a871749f32232cdbc0d9c Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 14 Jan 2013 15:24:08 -0800 Subject: Additional tests for MapOutputTracker. --- .../test/scala/spark/MapOutputTrackerSuite.scala | 82 +++++++++++++++++++++- 1 file changed, 80 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 5b4b198960..6c6f82e274 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -1,12 +1,18 @@ package spark import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter import akka.actor._ import spark.scheduler.MapStatus import spark.storage.BlockManagerId +import spark.util.AkkaUtils -class MapOutputTrackerSuite extends FunSuite { +class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { + after { + System.clearProperty("spark.master.port") + } + test("compressSize") { assert(MapOutputTracker.compressSize(0L) === 0) assert(MapOutputTracker.compressSize(1L) === 1) @@ -71,6 +77,78 @@ class MapOutputTrackerSuite extends FunSuite { // The remaining reduce task might try to grab the output dispite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. - intercept[Exception] { tracker.getServerStatuses(10, 1) } + intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } + } + + test("remote fetch") { + val (actorSystem, boundPort) = + AkkaUtils.createActorSystem("test", "localhost", 0) + System.setProperty("spark.master.port", boundPort.toString) + val masterTracker = new MapOutputTracker(actorSystem, true) + val slaveTracker = new MapOutputTracker(actorSystem, false) + masterTracker.registerShuffle(10, 1) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + new BlockManagerId("hostA", 1000), Array(compressedSize1000))) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((new BlockManagerId("hostA", 1000), size1000))) + + masterTracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + } + + test("simulatenous fetch fails") { + val dummyActorSystem = ActorSystem("testDummy") + val dummyTracker = new MapOutputTracker(dummyActorSystem, true) + dummyTracker.registerShuffle(10, 1) + // val compressedSize1000 = MapOutputTracker.compressSize(1000L) + // val size100 = MapOutputTracker.decompressSize(compressedSize1000) + // dummyTracker.registerMapOutput(10, 0, new MapStatus( + // new BlockManagerId("hostA", 1000), Array(compressedSize1000))) + val serializedMessage = dummyTracker.getSerializedLocations(10) + + val (actorSystem, boundPort) = + AkkaUtils.createActorSystem("test", "localhost", 0) + System.setProperty("spark.master.port", boundPort.toString) + val delayResponseLock = new java.lang.Object + val delayResponseActor = actorSystem.actorOf(Props(new Actor { + override def receive = { + case GetMapOutputStatuses(shuffleId: Int, requester: String) => + delayResponseLock.synchronized { + sender ! serializedMessage + } + } + }), name = "MapOutputTracker") + val slaveTracker = new MapOutputTracker(actorSystem, false) + var firstFailed = false + var secondFailed = false + val firstFetch = new Thread { + override def run() { + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + firstFailed = true + } + } + val secondFetch = new Thread { + override def run() { + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + secondFailed = true + } + } + delayResponseLock.synchronized { + firstFetch.start + secondFetch.start + } + firstFetch.join + secondFetch.join + assert(firstFailed && secondFailed) } } -- cgit v1.2.3 From b0389997972d383c3aaa87924b725dee70b18d8e Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 14 Jan 2013 17:04:44 -0800 Subject: Fix accidental spark.master.host reuse --- core/src/test/scala/spark/MapOutputTrackerSuite.scala | 2 ++ 1 file changed, 2 insertions(+) (limited to 'core') diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 6c6f82e274..aa1d8ac7e6 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -81,6 +81,7 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { } test("remote fetch") { + System.clearProperty("spark.master.host") val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) System.setProperty("spark.master.port", boundPort.toString) @@ -107,6 +108,7 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { } test("simulatenous fetch fails") { + System.clearProperty("spark.master.host") val dummyActorSystem = ActorSystem("testDummy") val dummyTracker = new MapOutputTracker(dummyActorSystem, true) dummyTracker.registerShuffle(10, 1) -- cgit v1.2.3 From dd583b7ebf0e6620ec8e35424b59db451febe3e8 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 15 Jan 2013 10:52:06 -0600 Subject: Call executeOnCompleteCallbacks in a finally block. --- core/src/main/scala/spark/scheduler/ResultTask.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index e492279b4e..2aad7956b4 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -15,9 +15,11 @@ private[spark] class ResultTask[T, U]( override def run(attemptId: Long): U = { val context = new TaskContext(stageId, partition, attemptId) - val result = func(context, rdd.iterator(split, context)) - context.executeOnCompleteCallbacks() - result + try { + func(context, rdd.iterator(split, context)) + } finally { + context.executeOnCompleteCallbacks() + } } override def preferredLocations: Seq[String] = locs -- cgit v1.2.3 From d228bff440395e8e6b8d67483467dde65b08ab40 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 15 Jan 2013 11:48:50 -0600 Subject: Add a test. --- .../scala/spark/scheduler/TaskContextSuite.scala | 43 ++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 core/src/test/scala/spark/scheduler/TaskContextSuite.scala (limited to 'core') diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala new file mode 100644 index 0000000000..f937877340 --- /dev/null +++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala @@ -0,0 +1,43 @@ +package spark.scheduler + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import spark.TaskContext +import spark.RDD +import spark.SparkContext +import spark.Split + +class TaskContextSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + + after { + if (sc != null) { + sc.stop() + sc = null + } + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") + } + + test("Calls executeOnCompleteCallbacks after failure") { + var completed = false + sc = new SparkContext("local", "test") + val rdd = new RDD[String](sc) { + override val splits = Array[Split](StubSplit(0)) + override val dependencies = List() + override def compute(split: Split, context: TaskContext) = { + context.addOnCompleteCallback(() => completed = true) + sys.error("failed") + } + } + val func = (c: TaskContext, i: Iterator[String]) => i.next + val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0) + intercept[RuntimeException] { + task.run(0) + } + assert(completed === true) + } + + case class StubSplit(val index: Int) extends Split +} \ No newline at end of file -- cgit v1.2.3 From 74d3b23929758328c2a7879381669d81bf899396 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 15 Jan 2013 14:03:28 -0600 Subject: Add spark.executor.memory to differentiate executor memory from spark-shell memory. --- core/src/main/scala/spark/SparkContext.scala | 4 ++-- core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala | 3 +-- .../spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 11 +++++------ 3 files changed, 8 insertions(+), 10 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bbf8272eb3..a5a1b75944 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -111,8 +111,8 @@ class SparkContext( // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() - for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", - "SPARK_TESTING")) { + // Note: SPARK_MEM isn't included because it's set directly in ExecutorRunner + for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) { val value = System.getenv(key) if (value != null) { executorEnvs(key) = value diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index beceb55ecd..2f2ea617ff 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -118,8 +118,7 @@ private[spark] class ExecutorRunner( for ((key, value) <- jobDesc.command.environment) { env.put(key, value) } - env.put("SPARK_CORES", cores.toString) - env.put("SPARK_MEMORY", memory.toString) + env.put("SPARK_MEM", memory.toString) // In case we are running this from within the Spark Shell, avoid creating a "scala" // parent process for the executor command env.put("SPARK_LAUNCH_WITH_SCALA", "0") diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index e2301347e5..f2fb244b24 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -23,12 +23,11 @@ private[spark] class SparkDeploySchedulerBackend( // Memory used by each executor (in megabytes) val executorMemory = { - if (System.getenv("SPARK_MEM") != null) { - Utils.memoryStringToMb(System.getenv("SPARK_MEM")) - // TODO: Might need to add some extra memory for the non-heap parts of the JVM - } else { - 512 - } + // TODO: Might need to add some extra memory for the non-heap parts of the JVM + Option(System.getProperty("spark.executor.memory")) + .orElse(Option(System.getenv("SPARK_MEM"))) + .map(Utils.memoryStringToMb) + .getOrElse(512) } override def start() { -- cgit v1.2.3 From 4078623b9f2a338d4992c3dfd3af3a5550615180 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 15 Jan 2013 12:05:54 -0800 Subject: Remove broken attempt to test fetching case. --- .../test/scala/spark/MapOutputTrackerSuite.scala | 48 +--------------------- 1 file changed, 2 insertions(+), 46 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index aa1d8ac7e6..d3dd3a8fa4 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -105,52 +105,8 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - } - - test("simulatenous fetch fails") { - System.clearProperty("spark.master.host") - val dummyActorSystem = ActorSystem("testDummy") - val dummyTracker = new MapOutputTracker(dummyActorSystem, true) - dummyTracker.registerShuffle(10, 1) - // val compressedSize1000 = MapOutputTracker.compressSize(1000L) - // val size100 = MapOutputTracker.decompressSize(compressedSize1000) - // dummyTracker.registerMapOutput(10, 0, new MapStatus( - // new BlockManagerId("hostA", 1000), Array(compressedSize1000))) - val serializedMessage = dummyTracker.getSerializedLocations(10) - val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("test", "localhost", 0) - System.setProperty("spark.master.port", boundPort.toString) - val delayResponseLock = new java.lang.Object - val delayResponseActor = actorSystem.actorOf(Props(new Actor { - override def receive = { - case GetMapOutputStatuses(shuffleId: Int, requester: String) => - delayResponseLock.synchronized { - sender ! serializedMessage - } - } - }), name = "MapOutputTracker") - val slaveTracker = new MapOutputTracker(actorSystem, false) - var firstFailed = false - var secondFailed = false - val firstFetch = new Thread { - override def run() { - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - firstFailed = true - } - } - val secondFetch = new Thread { - override def run() { - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - secondFailed = true - } - } - delayResponseLock.synchronized { - firstFetch.start - secondFetch.start - } - firstFetch.join - secondFetch.join - assert(firstFailed && secondFailed) + // failure should be cached + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } } } -- cgit v1.2.3 From a805ac4a7cdd520b6141dd885c780c526bb54ba6 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 Jan 2013 10:55:26 -0800 Subject: Disabled checkpoint for PairwiseRDD (pySpark). --- core/src/main/scala/spark/api/python/PythonRDD.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 276035a9ad..0138b22d38 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -138,6 +138,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends case Seq(a, b) => (a, b) case x => throw new Exception("PairwiseRDD: unexpected value: " + x) } + override def checkpoint() { } val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } -- cgit v1.2.3 From eae698f755f41fd8bdff94c498df314ed74aa3c1 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 16 Jan 2013 12:21:37 -0800 Subject: remove unused thread pool --- core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala | 3 --- 1 file changed, 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 915f71ba9f..a29bf974d2 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -24,9 +24,6 @@ private[spark] class StandaloneExecutorBackend( with ExecutorBackend with Logging { - val threadPool = new ThreadPoolExecutor( - 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) - var master: ActorRef = null override def preStart() { -- cgit v1.2.3 From 742bc841adb2a57b05e7a155681a162ab9dfa2c1 Mon Sep 17 00:00:00 2001 From: Fernand Pajot Date: Thu, 17 Jan 2013 16:56:11 -0800 Subject: changed HttpBroadcast server cache to be in spark.local.dir instead of java.io.tmpdir --- core/src/main/scala/spark/broadcast/HttpBroadcast.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 7eb4ddb74f..96dc28f12a 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -89,7 +89,7 @@ private object HttpBroadcast extends Logging { } private def createServer() { - broadcastDir = Utils.createTempDir() + broadcastDir = Utils.createTempDir(System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) server = new HttpServer(broadcastDir) server.start() serverUri = server.uri -- cgit v1.2.3 From 54c0f9f185576e9b844fa8f81ca410f188daa51c Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 17 Jan 2013 17:40:55 -0800 Subject: Fix code that assumed spark.local.dir is only a single directory --- core/src/main/scala/spark/Utils.scala | 11 ++++++++++- core/src/main/scala/spark/broadcast/HttpBroadcast.scala | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 0e7007459d..aeed5d2f32 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -134,7 +134,7 @@ private object Utils extends Logging { */ def fetchFile(url: String, targetDir: File) { val filename = url.split("/").last - val tempDir = System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")) + val tempDir = getLocalDir val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) val targetFile = new File(targetDir, filename) val uri = new URI(url) @@ -204,6 +204,15 @@ private object Utils extends Logging { FileUtil.chmod(filename, "a+x") } + /** + * Get a temporary directory using Spark's spark.local.dir property, if set. This will always + * return a single directory, even though the spark.local.dir property might be a list of + * multiple paths. + */ + def getLocalDir: String = { + System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0) + } + /** * Shuffle the elements of a collection into a random order, returning the * result in a new collection. Unlike scala.util.Random.shuffle, this method diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 96dc28f12a..856a4683a9 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -89,7 +89,7 @@ private object HttpBroadcast extends Logging { } private def createServer() { - broadcastDir = Utils.createTempDir(System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) + broadcastDir = Utils.createTempDir(Utils.getLocalDir) server = new HttpServer(broadcastDir) server.start() serverUri = server.uri -- cgit v1.2.3 From d5570c7968baba1c1fe86c68dc1c388fae23907b Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 16 Jan 2013 21:07:49 -0800 Subject: Adding checkpointing to Java API --- .../main/scala/spark/api/java/JavaRDDLike.scala | 28 ++++++++++++++++++++++ .../scala/spark/api/java/JavaSparkContext.scala | 26 ++++++++++++++++++++ core/src/test/scala/spark/JavaAPISuite.java | 27 +++++++++++++++++++++ 3 files changed, 81 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 81d3a94466..958f5c26a1 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -9,6 +9,7 @@ import spark.api.java.JavaPairRDD._ import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} import spark.partial.{PartialResult, BoundedDouble} import spark.storage.StorageLevel +import com.google.common.base.Optional trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { @@ -298,4 +299,31 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Save this RDD as a SequenceFile of serialized objects. */ def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path) + + /** + * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` + * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. + * This is used to truncate very long lineages. In the current implementation, Spark will save + * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. + * Hence, it is strongly recommended to use checkpoint() on RDDs when + * (i) checkpoint() is called before the any job has been executed on this RDD. + * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will + * require recomputation. + */ + def checkpoint() = rdd.checkpoint() + + /** + * Return whether this RDD has been checkpointed or not + */ + def isCheckpointed(): Boolean = rdd.isCheckpointed() + + /** + * Gets the name of the file to which this RDD was checkpointed + */ + def getCheckpointFile(): Optional[String] = { + rdd.getCheckpointFile match { + case Some(file) => Optional.of(file) + case _ => Optional.absent() + } + } } diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index bf9ad7a200..22bfa2280d 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -342,6 +342,32 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def clearFiles() { sc.clearFiles() } + + /** + * Set the directory under which RDDs are going to be checkpointed. This method will + * create this directory and will throw an exception of the path already exists (to avoid + * overwriting existing files may be overwritten). The directory will be deleted on exit + * if indicated. + */ + def setCheckpointDir(dir: String, useExisting: Boolean) { + sc.setCheckpointDir(dir, useExisting) + } + + /** + * Set the directory under which RDDs are going to be checkpointed. This method will + * create this directory and will throw an exception of the path already exists (to avoid + * overwriting existing files may be overwritten). The directory will be deleted on exit + * if indicated. + */ + def setCheckpointDir(dir: String) { + sc.setCheckpointDir(dir) + } + + protected def checkpointFile[T](path: String): JavaRDD[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + new JavaRDD(sc.checkpointFile(path)) + } } object JavaSparkContext { diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index b99e790093..0b5354774b 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -625,4 +625,31 @@ public class JavaAPISuite implements Serializable { }); Assert.assertEquals((Float) 25.0f, floatAccum.value()); } + + @Test + public void checkpointAndComputation() { + File tempDir = Files.createTempDir(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + sc.setCheckpointDir(tempDir.getAbsolutePath(), true); + Assert.assertEquals(false, rdd.isCheckpointed()); + rdd.checkpoint(); + rdd.count(); // Forces the DAG to cause a checkpoint + Assert.assertEquals(true, rdd.isCheckpointed()); + Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect()); + } + + @Test + public void checkpointAndRestore() { + File tempDir = Files.createTempDir(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + sc.setCheckpointDir(tempDir.getAbsolutePath(), true); + Assert.assertEquals(false, rdd.isCheckpointed()); + rdd.checkpoint(); + rdd.count(); // Forces the DAG to cause a checkpoint + Assert.assertEquals(true, rdd.isCheckpointed()); + + Assert.assertTrue(rdd.getCheckpointFile().isPresent()); + JavaRDD recovered = sc.checkpointFile(rdd.getCheckpointFile().get()); + Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect()); + } } -- cgit v1.2.3 From 214345ceace634ec9cc83c4c85b233b699e0d219 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 19 Jan 2013 23:50:17 -0800 Subject: Fixed issue https://spark-project.atlassian.net/browse/STREAMING-29, along with updates to doc comments in SparkContext.checkpoint(). --- core/src/main/scala/spark/RDD.scala | 17 ++++++++--------- core/src/main/scala/spark/RDDCheckpointData.scala | 2 +- core/src/main/scala/spark/SparkContext.scala | 13 +++++++------ streaming/src/main/scala/spark/streaming/DStream.scala | 8 +++++++- 4 files changed, 23 insertions(+), 17 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index a9f2e86455..e0d2eabb1d 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -549,17 +549,16 @@ abstract class RDD[T: ClassManifest]( } /** - * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` - * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. - * This is used to truncate very long lineages. In the current implementation, Spark will save - * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. - * Hence, it is strongly recommended to use checkpoint() on RDDs when - * (i) checkpoint() is called before the any job has been executed on this RDD. - * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will - * require recomputation. + * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint + * directory set with SparkContext.setCheckpointDir() and all references to its parent + * RDDs will be removed. This function must be called before any job has been + * executed on this RDD. It is strongly recommended that this RDD is persisted in + * memory, otherwise saving it on a file will require recomputation. */ def checkpoint() { - if (checkpointData.isEmpty) { + if (context.checkpointDir.isEmpty) { + throw new Exception("Checkpoint directory has not been set in the SparkContext") + } else if (checkpointData.isEmpty) { checkpointData = Some(new RDDCheckpointData(this)) checkpointData.get.markForCheckpoint() } diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index d845a522e4..18df530b7d 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -63,7 +63,7 @@ extends Logging with Serializable { } // Save to file, and reload it as an RDD - val path = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString + val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _) val newRDD = new CheckpointRDD[T](rdd.context, path) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 88cf357ebf..7f3259d982 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -184,7 +184,7 @@ class SparkContext( private var dagScheduler = new DAGScheduler(taskScheduler) - private[spark] var checkpointDir: String = null + private[spark] var checkpointDir: Option[String] = None // Methods for creating RDDs @@ -595,10 +595,11 @@ class SparkContext( } /** - * Set the directory under which RDDs are going to be checkpointed. This method will - * create this directory and will throw an exception of the path already exists (to avoid - * overwriting existing files may be overwritten). The directory will be deleted on exit - * if indicated. + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists and useExisting is set to true, then the + * exisiting directory will be used. Otherwise an exception will be thrown to + * prevent accidental overriding of checkpoint files in the existing directory. */ def setCheckpointDir(dir: String, useExisting: Boolean = false) { val path = new Path(dir) @@ -610,7 +611,7 @@ class SparkContext( fs.mkdirs(path) } } - checkpointDir = dir + checkpointDir = Some(dir) } /** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */ diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index fbe3cebd6d..c4442b6a0c 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -154,10 +154,16 @@ abstract class DStream[T: ClassManifest] ( assert( !mustCheckpoint || checkpointDuration != null, - "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " + + "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set." + " Please use DStream.checkpoint() to set the interval." ) + assert( + checkpointDuration == null || ssc.sc.checkpointDir.isDefined, + "The checkpoint directory has not been set. Please use StreamingContext.checkpoint()" + + " or SparkContext.checkpoint() to set the checkpoint directory." + ) + assert( checkpointDuration == null || checkpointDuration >= slideDuration, "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + -- cgit v1.2.3 From 8e7f098a2c9e5e85cb9435f28d53a3a5847c14aa Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 01:57:44 -0800 Subject: Added accumulators to PySpark --- .../main/scala/spark/api/python/PythonRDD.scala | 83 +++++++++-- python/pyspark/__init__.py | 4 + python/pyspark/accumulators.py | 166 +++++++++++++++++++++ python/pyspark/context.py | 38 +++++ python/pyspark/rdd.py | 2 +- python/pyspark/serializers.py | 7 +- python/pyspark/shell.py | 4 +- python/pyspark/worker.py | 7 +- 8 files changed, 290 insertions(+), 21 deletions(-) create mode 100644 python/pyspark/accumulators.py (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index f431ef28d3..fb13e84658 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -1,7 +1,8 @@ package spark.api.python import java.io._ -import java.util.{List => JList} +import java.net._ +import java.util.{List => JList, ArrayList => JArrayList, Collections} import scala.collection.JavaConversions._ import scala.io.Source @@ -10,25 +11,26 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import spark.broadcast.Broadcast import spark._ import spark.rdd.PipedRDD -import java.util private[spark] class PythonRDD[T: ClassManifest]( - parent: RDD[T], - command: Seq[String], - envVars: java.util.Map[String, String], - preservePartitoning: Boolean, - pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) + parent: RDD[T], + command: Seq[String], + envVars: java.util.Map[String, String], + preservePartitoning: Boolean, + pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent.context) { // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + preservePartitoning: Boolean, pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) = this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec, - broadcastVars) + broadcastVars, accumulator) override def splits = parent.splits @@ -93,18 +95,30 @@ private[spark] class PythonRDD[T: ClassManifest]( // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(proc.getInputStream) return new Iterator[Array[Byte]] { - def next() = { + def next(): Array[Byte] = { val obj = _nextObj _nextObj = read() obj } - private def read() = { + private def read(): Array[Byte] = { try { val length = stream.readInt() - val obj = new Array[Byte](length) - stream.readFully(obj) - obj + if (length != -1) { + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + } else { + // We've finished the data section of the output, but we can still read some + // accumulator updates; let's do that, breaking when we get EOFException + while (true) { + val len2 = stream.readInt() + val update = new Array[Byte](len2) + stream.readFully(update) + accumulator += Collections.singletonList(update) + } + new Array[Byte](0) + } } catch { case eof: EOFException => { val exitStatus = proc.waitFor() @@ -246,3 +260,40 @@ private class ExtractValue extends spark.api.java.function.Function[(Array[Byte] private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } + +/** + * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it + * collects a list of pickled strings that we pass to Python through a socket. + */ +class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) + extends AccumulatorParam[JList[Array[Byte]]] { + + override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList + + override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) + : JList[Array[Byte]] = { + if (serverHost == null) { + // This happens on the worker node, where we just want to remember all the updates + val1.addAll(val2) + val1 + } else { + // This happens on the master, where we pass the updates to Python through a socket + val socket = new Socket(serverHost, serverPort) + val in = socket.getInputStream + val out = new DataOutputStream(socket.getOutputStream) + out.writeInt(val2.size) + for (array <- val2) { + out.writeInt(array.length) + out.write(array) + } + out.flush() + // Wait for a byte from the Python side as an acknowledgement + val byteRead = in.read() + if (byteRead == -1) { + throw new SparkException("EOF reached before Python server acknowledged") + } + socket.close() + null + } + } +} diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index c595ae0842..00666bc0a3 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -7,6 +7,10 @@ Public classes: Main entry point for Spark functionality. - L{RDD} A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + - L{Broadcast} + A broadcast variable that gets reused across tasks. + - L{Accumulator} + An "add-only" shared variable that tasks can only add values to. """ import sys import os diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py new file mode 100644 index 0000000000..438af4cfc0 --- /dev/null +++ b/python/pyspark/accumulators.py @@ -0,0 +1,166 @@ +""" +>>> from pyspark.context import SparkContext +>>> sc = SparkContext('local', 'test') +>>> a = sc.accumulator(1) +>>> a.value +1 +>>> a.value = 2 +>>> a.value +2 +>>> a += 5 +>>> a.value +7 + +>>> rdd = sc.parallelize([1,2,3]) +>>> def f(x): +... global a +... a += x +>>> rdd.foreach(f) +>>> a.value +13 + +>>> class VectorAccumulatorParam(object): +... def zero(self, value): +... return [0.0] * len(value) +... def addInPlace(self, val1, val2): +... for i in xrange(len(val1)): +... val1[i] += val2[i] +... return val1 +>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) +>>> va.value +[1.0, 2.0, 3.0] +>>> def g(x): +... global va +... va += [x] * 3 +>>> rdd.foreach(g) +>>> va.value +[7.0, 8.0, 9.0] + +>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Py4JJavaError:... + +>>> def h(x): +... global a +... a.value = 7 +>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Py4JJavaError:... + +>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Exception:... +""" + +import struct +import SocketServer +import threading +from pyspark.cloudpickle import CloudPickler +from pyspark.serializers import read_int, read_with_length, load_pickle + + +# Holds accumulators registered on the current machine, keyed by ID. This is then used to send +# the local accumulator updates back to the driver program at the end of a task. +_accumulatorRegistry = {} + + +def _deserialize_accumulator(aid, zero_value, accum_param): + from pyspark.accumulators import _accumulatorRegistry + accum = Accumulator(aid, zero_value, accum_param) + accum._deserialized = True + _accumulatorRegistry[aid] = accum + return accum + + +class Accumulator(object): + def __init__(self, aid, value, accum_param): + """Create a new Accumulator with a given initial value and AccumulatorParam object""" + from pyspark.accumulators import _accumulatorRegistry + self.aid = aid + self.accum_param = accum_param + self._value = value + self._deserialized = False + _accumulatorRegistry[aid] = self + + def __reduce__(self): + """Custom serialization; saves the zero value from our AccumulatorParam""" + param = self.accum_param + return (_deserialize_accumulator, (self.aid, param.zero(self._value), param)) + + @property + def value(self): + """Get the accumulator's value; only usable in driver program""" + if self._deserialized: + raise Exception("Accumulator.value cannot be accessed inside tasks") + return self._value + + @value.setter + def value(self, value): + """Sets the accumulator's value; only usable in driver program""" + if self._deserialized: + raise Exception("Accumulator.value cannot be accessed inside tasks") + self._value = value + + def __iadd__(self, term): + """The += operator; adds a term to this accumulator's value""" + self._value = self.accum_param.addInPlace(self._value, term) + return self + + def __str__(self): + return str(self._value) + + +class AddingAccumulatorParam(object): + """ + An AccumulatorParam that uses the + operators to add values. Designed for simple types + such as integers, floats, and lists. Requires the zero value for the underlying type + as a parameter. + """ + + def __init__(self, zero_value): + self.zero_value = zero_value + + def zero(self, value): + return self.zero_value + + def addInPlace(self, value1, value2): + value1 += value2 + return value1 + + +# Singleton accumulator params for some standard types +INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0) +DOUBLE_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) +COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) + + +class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + def handle(self): + from pyspark.accumulators import _accumulatorRegistry + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = load_pickle(read_with_length(self.rfile)) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + + +def _start_update_server(): + """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" + server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler) + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + return server + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index e486f206b0..1e2f845f9c 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -2,6 +2,8 @@ import os import atexit from tempfile import NamedTemporaryFile +from pyspark import accumulators +from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast from pyspark.java_gateway import launch_gateway from pyspark.serializers import dump_pickle, write_with_length, batched @@ -22,6 +24,7 @@ class SparkContext(object): _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile _takePartition = jvm.PythonRDD.takePartition + _next_accum_id = 0 def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): @@ -52,6 +55,14 @@ class SparkContext(object): self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, empty_string_array) + # Create a single Accumulator in Java that we'll send all our updates through; + # they will be passed back to us through a TCP server + self._accumulatorServer = accumulators._start_update_server() + (host, port) = self._accumulatorServer.server_address + self._javaAccumulator = self._jsc.accumulator( + self.jvm.java.util.ArrayList(), + self.jvm.PythonAccumulatorParam(host, port)) + self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have @@ -74,6 +85,8 @@ class SparkContext(object): def __del__(self): if self._jsc: self._jsc.stop() + if self._accumulatorServer: + self._accumulatorServer.shutdown() def stop(self): """ @@ -129,6 +142,31 @@ class SparkContext(object): return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) + def accumulator(self, value, accum_param=None): + """ + Create an C{Accumulator} with the given initial value, using a given + AccumulatorParam helper object to define how to add values of the data + type if provided. Default AccumulatorParams are used for integers and + floating-point numbers if you do not provide one. For other types, the + AccumulatorParam must implement two methods: + - C{zero(value)}: provide a "zero value" for the type, compatible in + dimensions with the provided C{value} (e.g., a zero vector). + - C{addInPlace(val1, val2)}: add two values of the accumulator's data + type, returning a new value; for efficiency, can also update C{val1} + in place and return it. + """ + if accum_param == None: + if isinstance(value, int): + accum_param = accumulators.INT_ACCUMULATOR_PARAM + elif isinstance(value, float): + accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM + elif isinstance(value, complex): + accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM + else: + raise Exception("No default accumulator param for type %s" % type(value)) + SparkContext._next_accum_id += 1 + return Accumulator(SparkContext._next_accum_id - 1, value, accum_param) + def addFile(self, path): """ Add a file to be downloaded into the working directory of this Spark diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1d36da42b0..d705f0f9e1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -703,7 +703,7 @@ class PipelinedRDD(RDD): env = MapConverter().convert(env, self.ctx.gateway._gateway_client) python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, class_manifest) + broadcast_vars, self.ctx._javaAccumulator, class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 9a5151ea00..115cf28cc2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -52,8 +52,13 @@ def read_int(stream): raise EOFError return struct.unpack("!i", length)[0] + +def write_int(value, stream): + stream.write(struct.pack("!i", value)) + + def write_with_length(obj, stream): - stream.write(struct.pack("!i", len(obj))) + write_int(len(obj), stream) stream.write(obj) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 7e6ad3aa76..f6328c561f 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -1,7 +1,7 @@ """ An interactive shell. -This fle is designed to be launched as a PYTHONSTARTUP script. +This file is designed to be launched as a PYTHONSTARTUP script. """ import os from pyspark.context import SparkContext @@ -14,4 +14,4 @@ print "Spark context avaiable as sc." # which allows us to execute the user's PYTHONSTARTUP file: _pythonstartup = os.environ.get('OLD_PYTHONSTARTUP') if _pythonstartup and os.path.isfile(_pythonstartup): - execfile(_pythonstartup) + execfile(_pythonstartup) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 3d792bbaa2..b2b9288089 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -5,9 +5,10 @@ import sys from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. +from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import write_with_length, read_with_length, \ +from pyspark.serializers import write_with_length, read_with_length, write_int, \ read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file @@ -36,6 +37,10 @@ def main(): iterator = read_from_pickle_file(sys.stdin) for obj in func(split_index, iterator): write_with_length(dumps(obj), old_stdout) + # Mark the beginning of the accumulators section of the output + write_int(-1, old_stdout) + for aid, accum in _accumulatorRegistry.items(): + write_with_length(dump_pickle((aid, accum._value)), old_stdout) if __name__ == '__main__': -- cgit v1.2.3 From 7ed1bf4b485131d58ea6728e7247b79320aca9e6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 16 Jan 2013 19:15:14 -0800 Subject: Add RDD checkpointing to Python API. --- .../main/scala/spark/api/python/PythonRDD.scala | 3 -- python/epydoc.conf | 2 +- python/pyspark/context.py | 9 +++++ python/pyspark/rdd.py | 34 ++++++++++++++++ python/pyspark/tests.py | 46 ++++++++++++++++++++++ python/run-tests | 3 ++ 6 files changed, 93 insertions(+), 4 deletions(-) create mode 100644 python/pyspark/tests.py (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 89f7c316dc..8c38262dd8 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -135,8 +135,6 @@ private[spark] class PythonRDD[T: ClassManifest]( } } - override def checkpoint() { } - val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } @@ -152,7 +150,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends case Seq(a, b) => (a, b) case x => throw new Exception("PairwiseRDD: unexpected value: " + x) } - override def checkpoint() { } val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } diff --git a/python/epydoc.conf b/python/epydoc.conf index 91ac984ba2..45102cd9fe 100644 --- a/python/epydoc.conf +++ b/python/epydoc.conf @@ -16,4 +16,4 @@ target: docs/ private: no exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers - pyspark.java_gateway pyspark.examples pyspark.shell + pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 1e2f845f9c..a438b43fdc 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -195,3 +195,12 @@ class SparkContext(object): filename = path.split("/")[-1] os.environ["PYTHONPATH"] = \ "%s:%s" % (filename, os.environ["PYTHONPATH"]) + + def setCheckpointDir(self, dirName, useExisting=False): + """ + Set the directory under which RDDs are going to be checkpointed. This + method will create this directory and will throw an exception of the + path already exists (to avoid overwriting existing files may be + overwritten). The directory will be deleted on exit if indicated. + """ + self._jsc.sc().setCheckpointDir(dirName, useExisting) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d705f0f9e1..9b676cae4a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -49,6 +49,40 @@ class RDD(object): self._jrdd.cache() return self + def checkpoint(self): + """ + Mark this RDD for checkpointing. The RDD will be saved to a file inside + `checkpointDir` (set using setCheckpointDir()) and all references to + its parent RDDs will be removed. This is used to truncate very long + lineages. In the current implementation, Spark will save this RDD to + a file (using saveAsObjectFile()) after the first job using this RDD is + done. Hence, it is strongly recommended to use checkpoint() on RDDs + when + + (i) checkpoint() is called before the any job has been executed on this + RDD. + + (ii) This RDD has been made to persist in memory. Otherwise saving it + on a file will require recomputation. + """ + self._jrdd.rdd().checkpoint() + + def isCheckpointed(self): + """ + Return whether this RDD has been checkpointed or not + """ + return self._jrdd.rdd().isCheckpointed() + + def getCheckpointFile(self): + """ + Gets the name of the file to which this RDD was checkpointed + """ + checkpointFile = self._jrdd.rdd().getCheckpointFile() + if checkpointFile.isDefined(): + return checkpointFile.get() + else: + return None + # TODO persist(self, storageLevel) def map(self, f, preservesPartitioning=False): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py new file mode 100644 index 0000000000..c959d5dec7 --- /dev/null +++ b/python/pyspark/tests.py @@ -0,0 +1,46 @@ +""" +Unit tests for PySpark; additional tests are implemented as doctests in +individual modules. +""" +import atexit +import os +import shutil +from tempfile import NamedTemporaryFile +import time +import unittest + +from pyspark.context import SparkContext + + +class TestCheckpoint(unittest.TestCase): + + def setUp(self): + self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2) + + def tearDown(self): + self.sc.stop() + + def test_basic_checkpointing(self): + checkpointDir = NamedTemporaryFile(delete=False) + os.unlink(checkpointDir.name) + self.sc.setCheckpointDir(checkpointDir.name) + + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertIsNone(flatMappedRDD.getCheckpointFile()) + + flatMappedRDD.checkpoint() + result = flatMappedRDD.collect() + time.sleep(1) # 1 second + self.assertTrue(flatMappedRDD.isCheckpointed()) + self.assertEqual(flatMappedRDD.collect(), result) + self.assertEqual(checkpointDir.name, + os.path.dirname(flatMappedRDD.getCheckpointFile())) + + atexit.register(lambda: shutil.rmtree(checkpointDir.name)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/run-tests b/python/run-tests index 32470911f9..ce214e98a8 100755 --- a/python/run-tests +++ b/python/run-tests @@ -14,6 +14,9 @@ FAILED=$(($?||$FAILED)) $FWDIR/pyspark -m doctest pyspark/accumulators.py FAILED=$(($?||$FAILED)) +$FWDIR/pyspark -m unittest pyspark.tests +FAILED=$(($?||$FAILED)) + if [[ $FAILED != 0 ]]; then echo -en "\033[31m" # Red echo "Had test failures; see logs." -- cgit v1.2.3 From 5b6ea9e9a04994553d0319c541ca356e2e3064a7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 20 Jan 2013 15:31:41 -0800 Subject: Update checkpointing API docs in Python/Java. --- core/src/main/scala/spark/api/java/JavaRDDLike.scala | 17 +++++++---------- .../main/scala/spark/api/java/JavaSparkContext.scala | 17 +++++++++-------- python/pyspark/context.py | 11 +++++++---- python/pyspark/rdd.py | 17 +++++------------ 4 files changed, 28 insertions(+), 34 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 087270e46d..b3698ffa44 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -307,16 +307,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] JavaPairRDD.fromRDD(rdd.keyBy(f)) } - - /** - * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` - * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. - * This is used to truncate very long lineages. In the current implementation, Spark will save - * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. - * Hence, it is strongly recommended to use checkpoint() on RDDs when - * (i) checkpoint() is called before the any job has been executed on this RDD. - * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will - * require recomputation. + + /** + * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint + * directory set with SparkContext.setCheckpointDir() and all references to its parent + * RDDs will be removed. This function must be called before any job has been + * executed on this RDD. It is strongly recommended that this RDD is persisted in + * memory, otherwise saving it on a file will require recomputation. */ def checkpoint() = rdd.checkpoint() diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index fa2f14113d..14699961ad 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -357,20 +357,21 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork } /** - * Set the directory under which RDDs are going to be checkpointed. This method will - * create this directory and will throw an exception of the path already exists (to avoid - * overwriting existing files may be overwritten). The directory will be deleted on exit - * if indicated. + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists and useExisting is set to true, then the + * exisiting directory will be used. Otherwise an exception will be thrown to + * prevent accidental overriding of checkpoint files in the existing directory. */ def setCheckpointDir(dir: String, useExisting: Boolean) { sc.setCheckpointDir(dir, useExisting) } /** - * Set the directory under which RDDs are going to be checkpointed. This method will - * create this directory and will throw an exception of the path already exists (to avoid - * overwriting existing files may be overwritten). The directory will be deleted on exit - * if indicated. + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists, an exception will be thrown to prevent accidental + * overriding of checkpoint files. */ def setCheckpointDir(dir: String) { sc.setCheckpointDir(dir) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 8beb8e2ae9..dcbed37270 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -202,9 +202,12 @@ class SparkContext(object): def setCheckpointDir(self, dirName, useExisting=False): """ - Set the directory under which RDDs are going to be checkpointed. This - method will create this directory and will throw an exception of the - path already exists (to avoid overwriting existing files may be - overwritten). The directory will be deleted on exit if indicated. + Set the directory under which RDDs are going to be checkpointed. The + directory must be a HDFS path if running on a cluster. + + If the directory does not exist, it will be created. If the directory + exists and C{useExisting} is set to true, then the exisiting directory + will be used. Otherwise an exception will be thrown to prevent + accidental overriding of checkpoint files in the existing directory. """ self._jsc.sc().setCheckpointDir(dirName, useExisting) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 2a2ff9b271..7b6ab956ee 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -52,18 +52,11 @@ class RDD(object): def checkpoint(self): """ - Mark this RDD for checkpointing. The RDD will be saved to a file inside - `checkpointDir` (set using setCheckpointDir()) and all references to - its parent RDDs will be removed. This is used to truncate very long - lineages. In the current implementation, Spark will save this RDD to - a file (using saveAsObjectFile()) after the first job using this RDD is - done. Hence, it is strongly recommended to use checkpoint() on RDDs - when - - (i) checkpoint() is called before the any job has been executed on this - RDD. - - (ii) This RDD has been made to persist in memory. Otherwise saving it + Mark this RDD for checkpointing. It will be saved to a file inside the + checkpoint directory set with L{SparkContext.setCheckpointDir()} and + all references to its parent RDDs will be removed. This function must + be called before any job has been executed on this RDD. It is strongly + recommended that this RDD is persisted in memory, otherwise saving it on a file will require recomputation. """ self.is_checkpointed = True -- cgit v1.2.3 From 9f211dd3f0132daf72fb39883fa4b28e4fd547ca Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 14 Jan 2013 15:30:42 -0800 Subject: Fix PythonPartitioner equality; see SPARK-654. PythonPartitioner did not take the Python-side partitioning function into account when checking for equality, which might cause problems in the future. --- .../main/scala/spark/api/python/PythonPartitioner.scala | 13 +++++++++++-- core/src/main/scala/spark/api/python/PythonRDD.scala | 5 ----- python/pyspark/rdd.py | 17 +++++++++++------ 3 files changed, 22 insertions(+), 13 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala index 648d9402b0..519e310323 100644 --- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -6,8 +6,17 @@ import java.util.Arrays /** * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API. + * + * Stores the unique id() of the Python-side partitioning function so that it is incorporated into + * equality comparisons. Correctness requires that the id is a unique identifier for the + * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning + * function). This can be ensured by using the Python id() function and maintaining a reference + * to the Python partitioning function so that its id() is not reused. */ -private[spark] class PythonPartitioner(override val numPartitions: Int) extends Partitioner { +private[spark] class PythonPartitioner( + override val numPartitions: Int, + val pyPartitionFunctionId: Long) + extends Partitioner { override def getPartition(key: Any): Int = { if (key == null) { @@ -32,7 +41,7 @@ private[spark] class PythonPartitioner(override val numPartitions: Int) extends override def equals(other: Any): Boolean = other match { case h: PythonPartitioner => - h.numPartitions == numPartitions + h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId case _ => false } diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 89f7c316dc..e4c0530241 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -252,11 +252,6 @@ private object Pickle { val APPENDS: Byte = 'e' } -private class ExtractValue extends spark.api.java.function.Function[(Array[Byte], - Array[Byte]), Array[Byte]] { - override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2 -} - private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d705f0f9e1..b58bf24e3e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -33,6 +33,7 @@ class RDD(object): self._jrdd = jrdd self.is_cached = False self.ctx = ctx + self._partitionFunc = None @property def context(self): @@ -497,7 +498,7 @@ class RDD(object): return python_right_outer_join(self, other, numSplits) # TODO: add option to control map-side combining - def partitionBy(self, numSplits, hashFunc=hash): + def partitionBy(self, numSplits, partitionFunc=hash): """ Return a copy of the RDD partitioned using the specified partitioner. @@ -514,17 +515,21 @@ class RDD(object): def add_shuffle_key(split, iterator): buckets = defaultdict(list) for (k, v) in iterator: - buckets[hashFunc(k) % numSplits].append((k, v)) + buckets[partitionFunc(k) % numSplits].append((k, v)) for (split, items) in buckets.iteritems(): yield str(split) yield dump_pickle(Batch(items)) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) - jrdd = pairRDD.partitionBy(partitioner) - jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) - return RDD(jrdd, self.ctx) + partitioner = self.ctx.jvm.PythonPartitioner(numSplits, + id(partitionFunc)) + jrdd = pairRDD.partitionBy(partitioner).values() + rdd = RDD(jrdd, self.ctx) + # This is required so that id(partitionFunc) remains unique, even if + # partitionFunc is a lambda: + rdd._partitionFunc = partitionFunc + return rdd # TODO: add control over map-side aggregation def combineByKey(self, createCombiner, mergeValue, mergeCombiners, -- cgit v1.2.3 From c0b9ceb8c3d56c6d6f6f6b5925c87abad06be646 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 00:23:53 -0800 Subject: Log remote lifecycle events in Akka for easier debugging --- core/src/main/scala/spark/util/AkkaUtils.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index e67cb0336d..fbd0ff46bf 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -32,6 +32,7 @@ private[spark] object AkkaUtils { akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] akka.actor.provider = "akka.remote.RemoteActorRefProvider" akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" + akka.remote.log-remote-lifecycle-events = on akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d akka.remote.netty.connection-timeout = %ds -- cgit v1.2.3 From 69a417858bf1627de5220d41afba64853d4bf64d Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 12:42:11 -0600 Subject: Also use hadoopConfiguration in newAPI methods. --- core/src/main/scala/spark/PairRDDFunctions.scala | 4 ++-- core/src/main/scala/spark/SparkContext.scala | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 51c15837c4..1c18736805 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -494,7 +494,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { - saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, new Configuration) + saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass) } /** @@ -506,7 +506,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration) { + conf: Configuration = self.context.hadoopConfiguration) { val job = new NewAPIHadoopJob(conf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index f6b98c41bc..303e5081a4 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -293,8 +293,7 @@ class SparkContext( path, fm.erasure.asInstanceOf[Class[F]], km.erasure.asInstanceOf[Class[K]], - vm.erasure.asInstanceOf[Class[V]], - new Configuration(hadoopConfiguration)) + vm.erasure.asInstanceOf[Class[V]]) } /** @@ -306,7 +305,7 @@ class SparkContext( fClass: Class[F], kClass: Class[K], vClass: Class[V], - conf: Configuration): RDD[(K, V)] = { + conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) val updatedConf = job.getConfiguration @@ -318,7 +317,7 @@ class SparkContext( * and extra configuration options to pass to the input format. */ def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( - conf: Configuration, + conf: Configuration = hadoopConfiguration, fClass: Class[F], kClass: Class[K], vClass: Class[V]): RDD[(K, V)] = { -- cgit v1.2.3 From f116d6b5c6029c2f96160bd84829a6fe8b73cccf Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 18 Jan 2013 13:24:37 -0800 Subject: executor can use a different sparkHome from Worker --- core/src/main/scala/spark/deploy/DeployMessage.scala | 4 +++- core/src/main/scala/spark/deploy/JobDescription.scala | 5 ++++- core/src/main/scala/spark/deploy/client/TestClient.scala | 3 ++- core/src/main/scala/spark/deploy/master/Master.scala | 9 +++++---- core/src/main/scala/spark/deploy/worker/Worker.scala | 4 ++-- .../spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 3 ++- 6 files changed, 18 insertions(+), 10 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 457122745b..7ee3e63429 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -5,6 +5,7 @@ import spark.deploy.master.{WorkerInfo, JobInfo} import spark.deploy.worker.ExecutorRunner import scala.collection.immutable.List import scala.collection.mutable.HashMap +import java.io.File private[spark] sealed trait DeployMessage extends Serializable @@ -42,7 +43,8 @@ private[spark] case class LaunchExecutor( execId: Int, jobDesc: JobDescription, cores: Int, - memory: Int) + memory: Int, + sparkHome: File) extends DeployMessage diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala index 20879c5f11..7f8f9af417 100644 --- a/core/src/main/scala/spark/deploy/JobDescription.scala +++ b/core/src/main/scala/spark/deploy/JobDescription.scala @@ -1,10 +1,13 @@ package spark.deploy +import java.io.File + private[spark] class JobDescription( val name: String, val cores: Int, val memoryPerSlave: Int, - val command: Command) + val command: Command, + val sparkHome: File) extends Serializable { val user = System.getProperty("user.name", "") diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index 57a7e123b7..dc743b1fbf 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -3,6 +3,7 @@ package spark.deploy.client import spark.util.AkkaUtils import spark.{Logging, Utils} import spark.deploy.{Command, JobDescription} +import java.io.File private[spark] object TestClient { @@ -25,7 +26,7 @@ private[spark] object TestClient { val url = args(0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) val desc = new JobDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map())) + "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), new File("dummy-spark-home")) val listener = new TestListener val client = new Client(actorSystem, url, desc, listener) client.start() diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 6ecebe626a..f0bee67159 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -6,6 +6,7 @@ import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, Remote import java.text.SimpleDateFormat import java.util.Date +import java.io.File import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -173,7 +174,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor for (pos <- 0 until numUsable) { if (assigned(pos) > 0) { val exec = job.addExecutor(usableWorkers(pos), assigned(pos)) - launchExecutor(usableWorkers(pos), exec) + launchExecutor(usableWorkers(pos), exec, job.desc.sparkHome) job.state = JobState.RUNNING } } @@ -186,7 +187,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor val coresToUse = math.min(worker.coresFree, job.coresLeft) if (coresToUse > 0) { val exec = job.addExecutor(worker, coresToUse) - launchExecutor(worker, exec) + launchExecutor(worker, exec, job.desc.sparkHome) job.state = JobState.RUNNING } } @@ -195,10 +196,10 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } } - def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) { + def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: File) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory) + worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 7c9e588ea2..078b2d8037 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -119,10 +119,10 @@ private[spark] class Worker( logError("Worker registration failed: " + message) System.exit(1) - case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_) => + case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) => logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name)) val manager = new ExecutorRunner( - jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, sparkHome, workDir) + jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, execSparkHome_, workDir) executors(jobId + "/" + execId) = manager manager.start() coresUsed += cores_ diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index e2301347e5..0dcc2efaca 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -4,6 +4,7 @@ import spark.{Utils, Logging, SparkContext} import spark.deploy.client.{Client, ClientListener} import spark.deploy.{Command, JobDescription} import scala.collection.mutable.HashMap +import java.io.File private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, @@ -39,7 +40,7 @@ private[spark] class SparkDeploySchedulerBackend( StandaloneSchedulerBackend.ACTOR_NAME) val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.sparkHome)) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() -- cgit v1.2.3 From aae5a920a4db0c31918a65a03ce7d2087826fd65 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 18 Jan 2013 13:28:50 -0800 Subject: get sparkHome the correct way --- .../scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 0dcc2efaca..08b9d6ff47 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -40,7 +40,7 @@ private[spark] class SparkDeploySchedulerBackend( StandaloneSchedulerBackend.ACTOR_NAME) val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.sparkHome)) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.getSparkHome())) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() -- cgit v1.2.3 From 5bf73df7f08b17719711a5f05f0b3390b4951272 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sat, 19 Jan 2013 13:26:15 -0800 Subject: oops, fix stupid compile error --- .../scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 08b9d6ff47..94886d3941 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -40,7 +40,8 @@ private[spark] class SparkDeploySchedulerBackend( StandaloneSchedulerBackend.ACTOR_NAME) val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.getSparkHome())) + val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sparkHome)) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() -- cgit v1.2.3 From c73107500e0a5b6c5f0b4aba8c4504ee4c2adbaf Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sun, 20 Jan 2013 21:55:50 -0800 Subject: send sparkHome as String instead of File over network --- core/src/main/scala/spark/deploy/DeployMessage.scala | 2 +- core/src/main/scala/spark/deploy/master/Master.scala | 2 +- core/src/main/scala/spark/deploy/worker/Worker.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 7ee3e63429..a4081ef89c 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -44,7 +44,7 @@ private[spark] case class LaunchExecutor( jobDesc: JobDescription, cores: Int, memory: Int, - sparkHome: File) + sparkHome: String) extends DeployMessage diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index f0bee67159..1b6f808a51 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -199,7 +199,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: File) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) + worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome.getAbsolutePath) exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 078b2d8037..19bf2be118 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -122,7 +122,7 @@ private[spark] class Worker( case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) => logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name)) val manager = new ExecutorRunner( - jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, execSparkHome_, workDir) + jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir) executors(jobId + "/" + execId) = manager manager.start() coresUsed += cores_ -- cgit v1.2.3 From fe26acc482f358bf87700f5e80160f7ce558cea7 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sun, 20 Jan 2013 21:57:44 -0800 Subject: remove unused imports --- core/src/main/scala/spark/deploy/DeployMessage.scala | 2 -- 1 file changed, 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index a4081ef89c..35f40c6e91 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -4,8 +4,6 @@ import spark.deploy.ExecutorState.ExecutorState import spark.deploy.master.{WorkerInfo, JobInfo} import spark.deploy.worker.ExecutorRunner import scala.collection.immutable.List -import scala.collection.mutable.HashMap -import java.io.File private[spark] sealed trait DeployMessage extends Serializable -- cgit v1.2.3 From a3f571b539ffd126e9f3bc3e9c7bedfcb6f4d2d4 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 21 Jan 2013 10:52:17 -0800 Subject: more File -> String changes --- core/src/main/scala/spark/deploy/JobDescription.scala | 4 +--- core/src/main/scala/spark/deploy/client/TestClient.scala | 3 +-- core/src/main/scala/spark/deploy/master/Master.scala | 5 ++--- .../scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 4 +--- 4 files changed, 5 insertions(+), 11 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala index 7f8f9af417..7160fc05fc 100644 --- a/core/src/main/scala/spark/deploy/JobDescription.scala +++ b/core/src/main/scala/spark/deploy/JobDescription.scala @@ -1,13 +1,11 @@ package spark.deploy -import java.io.File - private[spark] class JobDescription( val name: String, val cores: Int, val memoryPerSlave: Int, val command: Command, - val sparkHome: File) + val sparkHome: String) extends Serializable { val user = System.getProperty("user.name", "") diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index dc743b1fbf..8764c400e2 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -3,7 +3,6 @@ package spark.deploy.client import spark.util.AkkaUtils import spark.{Logging, Utils} import spark.deploy.{Command, JobDescription} -import java.io.File private[spark] object TestClient { @@ -26,7 +25,7 @@ private[spark] object TestClient { val url = args(0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) val desc = new JobDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), new File("dummy-spark-home")) + "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home") val listener = new TestListener val client = new Client(actorSystem, url, desc, listener) client.start() diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 1b6f808a51..2c2cd0231b 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -6,7 +6,6 @@ import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, Remote import java.text.SimpleDateFormat import java.util.Date -import java.io.File import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -196,10 +195,10 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } } - def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: File) { + def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome.getAbsolutePath) + worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 94886d3941..a21a5b2f3d 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -3,8 +3,6 @@ package spark.scheduler.cluster import spark.{Utils, Logging, SparkContext} import spark.deploy.client.{Client, ClientListener} import spark.deploy.{Command, JobDescription} -import scala.collection.mutable.HashMap -import java.io.File private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, @@ -41,7 +39,7 @@ private[spark] class SparkDeploySchedulerBackend( val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sparkHome)) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() -- cgit v1.2.3 From 4d34c7fc3ecd7a4d035005f84c01e6990c0c345e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 11:33:48 -0800 Subject: Fix compile error caused by cherry-pick --- .../main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index a21a5b2f3d..4f82cd96dd 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -3,6 +3,7 @@ package spark.scheduler.cluster import spark.{Utils, Logging, SparkContext} import spark.deploy.client.{Client, ClientListener} import spark.deploy.{Command, JobDescription} +import scala.collection.mutable.HashMap private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, -- cgit v1.2.3 From a88b44ed3b670633549049e9ccf990ea455e9720 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 11:59:21 -0800 Subject: Only bind to IPv4 addresses when trying to auto-detect external IP --- core/src/main/scala/spark/Utils.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index b3421df27c..692a3f4050 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -1,7 +1,7 @@ package spark import java.io._ -import java.net.{NetworkInterface, InetAddress, URL, URI} +import java.net.{NetworkInterface, InetAddress, Inet4Address, URL, URI} import java.util.{Locale, Random, UUID} import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} import org.apache.hadoop.conf.Configuration @@ -251,7 +251,8 @@ private object Utils extends Logging { // Address resolves to something like 127.0.1.1, which happens on Debian; try to find // a better address using the local network interfaces for (ni <- NetworkInterface.getNetworkInterfaces) { - for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && !addr.isLoopbackAddress) { + for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && + !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) { // We've found an address that looks reasonable! logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress + -- cgit v1.2.3 From ffd1623595cdce4080ad1e4e676e65898ebdd6dd Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 15:55:46 -0600 Subject: Minor cleanup. --- core/src/main/scala/spark/Accumulators.scala | 3 +-- core/src/main/scala/spark/Logging.scala | 3 +-- core/src/main/scala/spark/ParallelCollection.scala | 15 +++++---------- core/src/main/scala/spark/TaskContext.scala | 3 +-- core/src/main/scala/spark/rdd/BlockRDD.scala | 6 ++---- core/src/main/scala/spark/rdd/CartesianRDD.scala | 3 +-- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 6 ++---- core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 6 ++---- core/src/main/scala/spark/rdd/SampledRDD.scala | 5 ++--- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 3 +-- core/src/main/scala/spark/rdd/UnionRDD.scala | 3 +-- core/src/main/scala/spark/rdd/ZippedRDD.scala | 3 +-- .../scala/spark/scheduler/local/LocalScheduler.scala | 4 ++-- .../scheduler/mesos/CoarseMesosSchedulerBackend.scala | 16 ++++++---------- .../spark/scheduler/mesos/MesosSchedulerBackend.scala | 10 +++------- core/src/test/scala/spark/FileServerSuite.scala | 4 ++-- 16 files changed, 33 insertions(+), 60 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index b644aba5f8..57c6df35be 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -25,8 +25,7 @@ class Accumulable[R, T] ( extends Serializable { val id = Accumulators.newId - @transient - private var value_ = initialValue // Current value on master + @transient private var value_ = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers var deserialized = false diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala index 90bae26202..7c1c1bb144 100644 --- a/core/src/main/scala/spark/Logging.scala +++ b/core/src/main/scala/spark/Logging.scala @@ -11,8 +11,7 @@ import org.slf4j.LoggerFactory trait Logging { // Make the log field transient so that objects with Logging can // be serialized and used on another machine - @transient - private var log_ : Logger = null + @transient private var log_ : Logger = null // Method to get or create the logger for this object protected def log: Logger = { diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index ede933c9e9..ad23e5bec8 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -23,32 +23,28 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( } private[spark] class ParallelCollection[T: ClassManifest]( - @transient sc : SparkContext, + @transient sc: SparkContext, @transient data: Seq[T], numSlices: Int, - locationPrefs : Map[Int,Seq[String]]) + locationPrefs: Map[Int,Seq[String]]) extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split // instead. // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal. - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val slices = ParallelCollection.slice(data, numSlices).toArray slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray } - override def getSplits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_ override def compute(s: Split, context: TaskContext) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator override def getPreferredLocations(s: Split): Seq[String] = { - locationPrefs.get(s.index) match { - case Some(s) => s - case _ => Nil - } + locationPrefs.get(s.index) getOrElse Nil } override def clearDependencies() { @@ -56,7 +52,6 @@ private[spark] class ParallelCollection[T: ClassManifest]( } } - private object ParallelCollection { /** * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala index d2746b26b3..eab85f85a2 100644 --- a/core/src/main/scala/spark/TaskContext.scala +++ b/core/src/main/scala/spark/TaskContext.scala @@ -5,8 +5,7 @@ import scala.collection.mutable.ArrayBuffer class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable { - @transient - val onCompleteCallbacks = new ArrayBuffer[() => Unit] + @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit] // Add a callback function to be executed on task completion. An example use // is for HadoopRDD to register a callback to close the input stream. diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index b1095a52b4..2c022f88e0 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -11,13 +11,11 @@ private[spark] class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) extends RDD[T](sc, Nil) { - @transient - var splits_ : Array[Split] = (0 until blockIds.size).map(i => { + @transient var splits_ : Array[Split] = (0 until blockIds.size).map(i => { new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] }).toArray - @transient - lazy val locations_ = { + @transient lazy val locations_ = { val blockManager = SparkEnv.get.blockManager /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ val locations = blockManager.getLocations(blockIds) diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 79e7c24e7c..453d410ad4 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -35,8 +35,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( val numSplitsInRdd2 = rdd2.splits.size - @transient - var splits_ = { + @transient var splits_ = { // create the cross product split val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) for (s1 <- rdd1.splits; s2 <- rdd2.splits) { diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 1d528be2aa..8fafd27bb6 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -45,8 +45,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) val aggr = new CoGroupAggregator - @transient - var deps_ = { + @transient var deps_ = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { if (rdd.partitioner == Some(part)) { @@ -63,8 +62,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) override def getDependencies = deps_ - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val array = new Array[Split](part.numPartitions) for (i <- 0 until array.size) { array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) => diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index bb22db073c..c3b155fcbd 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -37,11 +37,9 @@ class NewHadoopRDD[K, V]( formatter.format(new Date()) } - @transient - private val jobId = new JobID(jobtrackerId, id) + @transient private val jobId = new JobID(jobtrackerId, id) - @transient - private val splits_ : Array[Split] = { + @transient private val splits_ : Array[Split] = { val inputFormat = inputFormatClass.newInstance val jobContext = newJobContext(conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index 1bc9c96112..e24ad23b21 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -19,13 +19,12 @@ class SampledRDD[T: ClassManifest]( seed: Int) extends RDD[T](prev) { - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val rg = new Random(seed) firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } - override def getSplits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_ override def getPreferredLocations(split: Split) = firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 1b219473e0..28ff19876d 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -22,8 +22,7 @@ class ShuffledRDD[K, V]( override val partitioner = Some(part) - @transient - var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) + @transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) override def getSplits = splits_ diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 24a085df02..82f0a44ecd 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -28,8 +28,7 @@ class UnionRDD[T: ClassManifest]( @transient var rdds: Seq[RDD[T]]) extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) var pos = 0 for (rdd <- rdds; split <- rdd.splits) { diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala index 16e6cc0f1b..d950b06c85 100644 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -34,8 +34,7 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( // TODO: FIX THIS. - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { if (rdd1.splits.size != rdd2.splits.size) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index dff550036d..21d255debd 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -19,8 +19,8 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon extends TaskScheduler with Logging { - var attemptId = new AtomicInteger(0) - var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) + val attemptId = new AtomicInteger(0) + val threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) val env = SparkEnv.get var listener: TaskSchedulerListener = null diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index c45c7df69c..014906b028 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -64,13 +64,9 @@ private[spark] class CoarseMesosSchedulerBackend( val taskIdToSlaveId = new HashMap[Int, String] val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed - val sparkHome = sc.getSparkHome() match { - case Some(path) => - path - case None => - throw new SparkException("Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor") - } + val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( + "Spark home is not set; set it through the spark.home system " + + "property, the SPARK_HOME environment variable or the SparkContext constructor")) val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt @@ -184,7 +180,7 @@ private[spark] class CoarseMesosSchedulerBackend( } /** Helper function to pull out a resource from a Mesos Resources protobuf */ - def getResource(res: JList[Resource], name: String): Double = { + private def getResource(res: JList[Resource], name: String): Double = { for (r <- res if r.getName == name) { return r.getScalar.getValue } @@ -193,7 +189,7 @@ private[spark] class CoarseMesosSchedulerBackend( } /** Build a Mesos resource protobuf object */ - def createResource(resourceName: String, quantity: Double): Protos.Resource = { + private def createResource(resourceName: String, quantity: Double): Protos.Resource = { Resource.newBuilder() .setName(resourceName) .setType(Value.Type.SCALAR) @@ -202,7 +198,7 @@ private[spark] class CoarseMesosSchedulerBackend( } /** Check whether a Mesos task state represents a finished task */ - def isFinished(state: MesosTaskState) = { + private def isFinished(state: MesosTaskState) = { state == MesosTaskState.TASK_FINISHED || state == MesosTaskState.TASK_FAILED || state == MesosTaskState.TASK_KILLED || diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index 8c7a1dfbc0..2989e31f5e 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -76,13 +76,9 @@ private[spark] class MesosSchedulerBackend( } def createExecutorInfo(): ExecutorInfo = { - val sparkHome = sc.getSparkHome() match { - case Some(path) => - path - case None => - throw new SparkException("Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor") - } + val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( + "Spark home is not set; set it through the spark.home system " + + "property, the SPARK_HOME environment variable or the SparkContext constructor")) val execScript = new File(sparkHome, "spark-executor").getCanonicalPath val environment = Environment.newBuilder() sc.executorEnvs.foreach { case (key, value) => diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala index b4283d9604..fe964bd893 100644 --- a/core/src/test/scala/spark/FileServerSuite.scala +++ b/core/src/test/scala/spark/FileServerSuite.scala @@ -9,8 +9,8 @@ import SparkContext._ class FileServerSuite extends FunSuite with BeforeAndAfter { @transient var sc: SparkContext = _ - @transient var tmpFile : File = _ - @transient var testJarFile : File = _ + @transient var tmpFile: File = _ + @transient var testJarFile: File = _ before { // Create a sample text file -- cgit v1.2.3 From ef711902c1f42db14c8ddd524195f0a9efb56e65 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Jan 2013 16:42:24 -0800 Subject: Don't download files to master's working directory. This should avoid exceptions caused by existing files with different contents. I also removed some unused code. --- core/src/main/scala/spark/HttpFileServer.scala | 8 ++--- core/src/main/scala/spark/SparkContext.scala | 7 ++-- core/src/main/scala/spark/SparkEnv.scala | 20 +++++++---- core/src/main/scala/spark/SparkFiles.java | 25 ++++++++++++++ core/src/main/scala/spark/Utils.scala | 16 +-------- .../scala/spark/api/java/JavaSparkContext.scala | 5 +-- .../main/scala/spark/api/python/PythonRDD.scala | 2 ++ .../scala/spark/deploy/worker/ExecutorRunner.scala | 5 --- core/src/main/scala/spark/executor/Executor.scala | 6 ++-- .../spark/scheduler/local/LocalScheduler.scala | 6 ++-- core/src/test/scala/spark/FileServerSuite.scala | 9 +++-- python/pyspark/__init__.py | 5 ++- python/pyspark/context.py | 40 +++++++++++++++++++--- python/pyspark/files.py | 24 +++++++++++++ python/pyspark/worker.py | 3 ++ python/run-tests | 3 ++ 16 files changed, 133 insertions(+), 51 deletions(-) create mode 100644 core/src/main/scala/spark/SparkFiles.java create mode 100644 python/pyspark/files.py (limited to 'core') diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala index 659d17718f..00901d95e2 100644 --- a/core/src/main/scala/spark/HttpFileServer.scala +++ b/core/src/main/scala/spark/HttpFileServer.scala @@ -1,9 +1,7 @@ package spark -import java.io.{File, PrintWriter} -import java.net.URL -import scala.collection.mutable.HashMap -import org.apache.hadoop.fs.FileUtil +import java.io.{File} +import com.google.common.io.Files private[spark] class HttpFileServer extends Logging { @@ -40,7 +38,7 @@ private[spark] class HttpFileServer extends Logging { } def addFileToDir(file: File, dir: File) : String = { - Utils.copyFile(file, new File(dir, file.getName)) + Files.copy(file, new File(dir, file.getName)) return dir + "/" + file.getName } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 8b6f4b3b7d..2eeca66ed6 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -439,9 +439,10 @@ class SparkContext( def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) /** - * Add a file to be downloaded into the working directory of this Spark job on every node. + * Add a file to be downloaded with this Spark job on every node. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(path)` to find its download location. */ def addFile(path: String) { val uri = new URI(path) @@ -454,7 +455,7 @@ class SparkContext( // Fetch the file locally in case a job is executed locally. // Jobs that run through LocalScheduler will already fetch the required dependencies, // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here. - Utils.fetchFile(path, new File(".")) + Utils.fetchFile(path, new File(SparkFiles.getRootDirectory)) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) } diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 41441720a7..6b44e29f4c 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -28,14 +28,10 @@ class SparkEnv ( val broadcastManager: BroadcastManager, val blockManager: BlockManager, val connectionManager: ConnectionManager, - val httpFileServer: HttpFileServer + val httpFileServer: HttpFileServer, + val sparkFilesDir: String ) { - /** No-parameter constructor for unit tests. */ - def this() = { - this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null) - } - def stop() { httpFileServer.stop() mapOutputTracker.stop() @@ -112,6 +108,15 @@ object SparkEnv extends Logging { httpFileServer.initialize() System.setProperty("spark.fileserver.uri", httpFileServer.serverUri) + // Set the sparkFiles directory, used when downloading dependencies. In local mode, + // this is a temporary directory; in distributed mode, this is the executor's current working + // directory. + val sparkFilesDir: String = if (isMaster) { + Utils.createTempDir().getAbsolutePath + } else { + "." + } + // Warn about deprecated spark.cache.class property if (System.getProperty("spark.cache.class") != null) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + @@ -128,6 +133,7 @@ object SparkEnv extends Logging { broadcastManager, blockManager, connectionManager, - httpFileServer) + httpFileServer, + sparkFilesDir) } } diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java new file mode 100644 index 0000000000..b59d8ce93f --- /dev/null +++ b/core/src/main/scala/spark/SparkFiles.java @@ -0,0 +1,25 @@ +package spark; + +import java.io.File; + +/** + * Resolves paths to files added through `addFile(). + */ +public class SparkFiles { + + private SparkFiles() {} + + /** + * Get the absolute path of a file added through `addFile()`. + */ + public static String get(String filename) { + return new File(getRootDirectory(), filename).getAbsolutePath(); + } + + /** + * Get the root directory that contains files added through `addFile()`. + */ + public static String getRootDirectory() { + return SparkEnv.get().sparkFilesDir(); + } +} \ No newline at end of file diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 692a3f4050..827c8bd81e 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -111,20 +111,6 @@ private object Utils extends Logging { } } - /** Copy a file on the local file system */ - def copyFile(source: File, dest: File) { - val in = new FileInputStream(source) - val out = new FileOutputStream(dest) - copyStream(in, out, true) - } - - /** Download a file from a given URL to the local filesystem */ - def downloadFile(url: URL, localPath: String) { - val in = url.openStream() - val out = new FileOutputStream(localPath) - Utils.copyStream(in, out, true) - } - /** * Download a file requested by the executor. Supports fetching the file in a variety of ways, * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. @@ -201,7 +187,7 @@ private object Utils extends Logging { Utils.execute(Seq("tar", "-xf", filename), targetDir) } // Make the file executable - That's necessary for scripts - FileUtil.chmod(filename, "a+x") + FileUtil.chmod(targetFile.getAbsolutePath, "a+x") } /** diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index 16c122c584..50b8970cd8 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -323,9 +323,10 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def getSparkHome(): Option[String] = sc.getSparkHome() /** - * Add a file to be downloaded into the working directory of this Spark job on every node. + * Add a file to be downloaded with this Spark job on every node. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(path)` to find its download location. */ def addFile(path: String) { sc.addFile(path) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 5526406a20..f43a152ca7 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -67,6 +67,8 @@ private[spark] class PythonRDD[T: ClassManifest]( val dOut = new DataOutputStream(proc.getOutputStream) // Split index dOut.writeInt(split.index) + // sparkFilesDir + PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut) // Broadcast variables dOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index beceb55ecd..0d1fe2a6b4 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -106,11 +106,6 @@ private[spark] class ExecutorRunner( throw new IOException("Failed to create directory " + executorDir) } - // Download the files it depends on into it (disabled for now) - //for (url <- jobDesc.fileUrls) { - // fetchFile(url, executorDir) - //} - // Launch the process val command = buildCommandSeq() val builder = new ProcessBuilder(command: _*).directory(executorDir) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 2552958d27..70629f6003 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -162,16 +162,16 @@ private[spark] class Executor extends Logging { // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentFiles(name) = timestamp } for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentJars(name) = timestamp // Add it to our class loader val localName = name.split("/").last - val url = new File(".", localName).toURI.toURL + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL if (!urlClassLoader.getURLs.contains(url)) { logInfo("Adding " + url + " to class loader") urlClassLoader.addURL(url) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index dff550036d..4451d314e6 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -116,16 +116,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentFiles(name) = timestamp } for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentJars(name) = timestamp // Add it to our class loader val localName = name.split("/").last - val url = new File(".", localName).toURI.toURL + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL if (!classLoader.getURLs.contains(url)) { logInfo("Adding " + url + " to class loader") classLoader.addURL(url) diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala index b4283d9604..528c6b8424 100644 --- a/core/src/test/scala/spark/FileServerSuite.scala +++ b/core/src/test/scala/spark/FileServerSuite.scala @@ -40,7 +40,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter { sc.addFile(tmpFile.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { - val in = new BufferedReader(new FileReader("FileServerSuite.txt")) + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) val fileVal = in.readLine().toInt in.close() _ * fileVal + _ * fileVal @@ -54,7 +55,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter { sc.addFile((new File(tmpFile.toString)).toURL.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { - val in = new BufferedReader(new FileReader("FileServerSuite.txt")) + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) val fileVal = in.readLine().toInt in.close() _ * fileVal + _ * fileVal @@ -83,7 +85,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter { sc.addFile(tmpFile.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { - val in = new BufferedReader(new FileReader("FileServerSuite.txt")) + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) val fileVal = in.readLine().toInt in.close() _ * fileVal + _ * fileVal diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 00666bc0a3..3e8bca62f0 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -11,6 +11,8 @@ Public classes: A broadcast variable that gets reused across tasks. - L{Accumulator} An "add-only" shared variable that tasks can only add values to. + - L{SparkFiles} + Access files shipped with jobs. """ import sys import os @@ -19,6 +21,7 @@ sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.eg from pyspark.context import SparkContext from pyspark.rdd import RDD +from pyspark.files import SparkFiles -__all__ = ["SparkContext", "RDD"] +__all__ = ["SparkContext", "RDD", "SparkFiles"] diff --git a/python/pyspark/context.py b/python/pyspark/context.py index dcbed37270..ec0cc7c2f9 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1,5 +1,7 @@ import os import atexit +import shutil +import tempfile from tempfile import NamedTemporaryFile from pyspark import accumulators @@ -173,10 +175,26 @@ class SparkContext(object): def addFile(self, path): """ - Add a file to be downloaded into the working directory of this Spark - job on every node. The C{path} passed can be either a local file, - a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, - HTTPS or FTP URI. + Add a file to be downloaded with this Spark job on every node. + The C{path} passed can be either a local file, a file in HDFS + (or other Hadoop-supported filesystems), or an HTTP, HTTPS or + FTP URI. + + To access the file in Spark jobs, use + L{SparkFiles.get(path)} to find its + download location. + + >>> from pyspark import SparkFiles + >>> path = os.path.join(tempdir, "test.txt") + >>> with open(path, "w") as testFile: + ... testFile.write("100") + >>> sc.addFile(path) + >>> def func(iterator): + ... with open(SparkFiles.get("test.txt")) as testFile: + ... fileVal = int(testFile.readline()) + ... return [x * 100 for x in iterator] + >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect() + [100, 200, 300, 400] """ self._jsc.sc().addFile(path) @@ -211,3 +229,17 @@ class SparkContext(object): accidental overriding of checkpoint files in the existing directory. """ self._jsc.sc().setCheckpointDir(dirName, useExisting) + + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['tempdir'] = tempfile.mkdtemp() + atexit.register(lambda: shutil.rmtree(globs['tempdir'])) + doctest.testmod(globs=globs) + globs['sc'].stop() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/files.py b/python/pyspark/files.py new file mode 100644 index 0000000000..de1334f046 --- /dev/null +++ b/python/pyspark/files.py @@ -0,0 +1,24 @@ +import os + + +class SparkFiles(object): + """ + Resolves paths to files added through + L{addFile()}. + + SparkFiles contains only classmethods; users should not create SparkFiles + instances. + """ + + _root_directory = None + + def __init__(self): + raise NotImplementedError("Do not construct SparkFiles objects") + + @classmethod + def get(cls, filename): + """ + Get the absolute path of a file added through C{addFile()}. + """ + path = os.path.join(SparkFiles._root_directory, filename) + return os.path.abspath(path) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b2b9288089..e7bdb7682b 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -8,6 +8,7 @@ from base64 import standard_b64decode from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler +from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, read_with_length, write_int, \ read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file @@ -23,6 +24,8 @@ def load_obj(): def main(): split_index = read_int(sys.stdin) + spark_files_dir = load_pickle(read_with_length(sys.stdin)) + SparkFiles._root_directory = spark_files_dir num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): bid = read_long(sys.stdin) diff --git a/python/run-tests b/python/run-tests index ce214e98a8..a3a9ff5dcb 100755 --- a/python/run-tests +++ b/python/run-tests @@ -8,6 +8,9 @@ FAILED=0 $FWDIR/pyspark pyspark/rdd.py FAILED=$(($?||$FAILED)) +$FWDIR/pyspark pyspark/context.py +FAILED=$(($?||$FAILED)) + $FWDIR/pyspark -m doctest pyspark/broadcast.py FAILED=$(($?||$FAILED)) -- cgit v1.2.3 From 7b9e96c99206c0679d9925e0161fde738a5c7c3a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Jan 2013 16:45:00 -0800 Subject: Add synchronization to Executor.updateDependencies() (SPARK-662) --- core/src/main/scala/spark/executor/Executor.scala | 34 ++++++++++++----------- 1 file changed, 18 insertions(+), 16 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 70629f6003..28d9d40d43 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -159,22 +159,24 @@ private[spark] class Executor extends Logging { * SparkContext. Also adds any new JARs we fetched to the class loader. */ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentFiles(name) = timestamp - } - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!urlClassLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - urlClassLoader.addURL(url) + synchronized { + // Fetch missing dependencies + for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentFiles(name) = timestamp + } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentJars(name) = timestamp + // Add it to our class loader + val localName = name.split("/").last + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL + if (!urlClassLoader.getURLs.contains(url)) { + logInfo("Adding " + url + " to class loader") + urlClassLoader.addURL(url) + } } } } -- cgit v1.2.3 From 2d8218b8717435a47d7cea399290b30bf5ef010b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 20:00:27 -0600 Subject: Remove unneeded/now-broken saveAsNewAPIHadoopFile overload. --- core/src/main/scala/spark/PairRDDFunctions.scala | 12 ------------ 1 file changed, 12 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 656b820b8a..53b051f1c5 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -485,18 +485,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) } - /** - * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` - * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. - */ - def saveAsNewAPIHadoopFile( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { - saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass) - } - /** * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. -- cgit v1.2.3 From a8baeb93272b03a98e44c7bf5c541611aec4a64b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 21:30:24 -0600 Subject: Further simplify getOrElse call. --- core/src/main/scala/spark/ParallelCollection.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index ad23e5bec8..10adcd53ec 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -44,7 +44,7 @@ private[spark] class ParallelCollection[T: ClassManifest]( s.asInstanceOf[ParallelCollectionSplit[T]].iterator override def getPreferredLocations(s: Split): Seq[String] = { - locationPrefs.get(s.index) getOrElse Nil + locationPrefs.getOrElse(s.index, Nil) } override def clearDependencies() { -- cgit v1.2.3 From 551a47a620c7dc207e3530e54d794a3c3aa8e45e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Jan 2013 23:31:00 -0800 Subject: Refactor daemon thread pool creation. --- .../src/main/scala/spark/DaemonThreadFactory.scala | 18 ------------ core/src/main/scala/spark/Utils.scala | 33 +++++----------------- .../scala/spark/network/ConnectionManager.scala | 5 ++-- .../spark/scheduler/local/LocalScheduler.scala | 2 +- .../spark/streaming/dstream/RawInputDStream.scala | 5 ++-- 5 files changed, 13 insertions(+), 50 deletions(-) delete mode 100644 core/src/main/scala/spark/DaemonThreadFactory.scala (limited to 'core') diff --git a/core/src/main/scala/spark/DaemonThreadFactory.scala b/core/src/main/scala/spark/DaemonThreadFactory.scala deleted file mode 100644 index 56e59adeb7..0000000000 --- a/core/src/main/scala/spark/DaemonThreadFactory.scala +++ /dev/null @@ -1,18 +0,0 @@ -package spark - -import java.util.concurrent.ThreadFactory - -/** - * A ThreadFactory that creates daemon threads - */ -private object DaemonThreadFactory extends ThreadFactory { - override def newThread(r: Runnable): Thread = new DaemonThread(r) -} - -private class DaemonThread(r: Runnable = null) extends Thread { - override def run() { - if (r != null) { - r.run() - } - } -} \ No newline at end of file diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 692a3f4050..9b8636f6c8 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -10,6 +10,7 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.io.Source import com.google.common.io.Files +import com.google.common.util.concurrent.ThreadFactoryBuilder /** * Various utility methods used by Spark. @@ -287,29 +288,14 @@ private object Utils extends Logging { customHostname.getOrElse(InetAddress.getLocalHost.getHostName) } - /** - * Returns a standard ThreadFactory except all threads are daemons. - */ - private def newDaemonThreadFactory: ThreadFactory = { - new ThreadFactory { - def newThread(r: Runnable): Thread = { - var t = Executors.defaultThreadFactory.newThread (r) - t.setDaemon (true) - return t - } - } - } + private[spark] val daemonThreadFactory: ThreadFactory = + new ThreadFactoryBuilder().setDaemon(true).build() /** * Wrapper over newCachedThreadPool. */ - def newDaemonCachedThreadPool(): ThreadPoolExecutor = { - var threadPool = Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory (newDaemonThreadFactory) - - return threadPool - } + def newDaemonCachedThreadPool(): ThreadPoolExecutor = + Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] /** * Return the string to tell how long has passed in seconds. The passing parameter should be in @@ -322,13 +308,8 @@ private object Utils extends Logging { /** * Wrapper over newFixedThreadPool. */ - def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { - var threadPool = Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory(newDaemonThreadFactory) - - return threadPool - } + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = + Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] /** * Delete a file or directory and its contents recursively. diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 36c01ad629..2ecd14f536 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -52,9 +52,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] val sendMessageRequests = new Queue[(Message, SendingConnection)] - implicit val futureExecContext = ExecutionContext.fromExecutor( - Executors.newCachedThreadPool(DaemonThreadFactory)) - + implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool()) + var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null serverChannel.configureBlocking(false) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index dff550036d..87f8474ea0 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -20,7 +20,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon with Logging { var attemptId = new AtomicInteger(0) - var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) + var threadPool = Utils.newDaemonFixedThreadPool(threads) val env = SparkEnv.get var listener: TaskSchedulerListener = null diff --git a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala index 290fab1ce0..04e6b69b7b 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala @@ -1,6 +1,6 @@ package spark.streaming.dstream -import spark.{DaemonThread, Logging} +import spark.Logging import spark.storage.StorageLevel import spark.streaming.StreamingContext @@ -48,7 +48,8 @@ class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel) val queue = new ArrayBlockingQueue[ByteBuffer](2) - blockPushingThread = new DaemonThread { + blockPushingThread = new Thread { + setDaemon(true) override def run() { var nextBlockNumber = 0 while (true) { -- cgit v1.2.3 From e353886a8ca6179f25b4176d7a62b5d04ce79276 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 22 Jan 2013 00:23:31 -0800 Subject: Use generation numbers for fetch failure tracking --- .../main/scala/spark/scheduler/DAGScheduler.scala | 27 ++++++++++++++++------ 1 file changed, 20 insertions(+), 7 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 59f2099e91..39a1e6d6c6 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -72,8 +72,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val cacheTracker = env.cacheTracker val mapOutputTracker = env.mapOutputTracker - val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back; - // that's not going to be a realistic assumption in general + // For tracking failed nodes, we use the MapOutputTracker's generation number, which is + // sent with every task. When we detect a node failing, we note the current generation number + // and failed host, increment it for new tasks, and use this to ignore stray ShuffleMapTask + // results. + // TODO: Garbage collect information about failure generations when new stages start. + val failedGeneration = new HashMap[String, Long] val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val running = new HashSet[Stage] // Stages we are running right now @@ -429,7 +433,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val status = event.result.asInstanceOf[MapStatus] val host = status.address.ip logInfo("ShuffleMapTask finished with host " + host) - if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos + if (failedGeneration.contains(host) && smt.generation <= failedGeneration(host)) { + logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + host) + } else { stage.addOutputLoc(smt.partition, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { @@ -495,7 +501,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock // TODO: mark the host as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleHostLost(bmAddress.ip) + handleHostLost(bmAddress.ip, Some(task.generation)) } case other => @@ -507,11 +513,15 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with /** * Responds to a host being lost. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside. + * + * Optionally the generation during which the failure was caught can be passed to avoid allowing + * stray fetch failures from possibly retriggering the detection of a node as lost. */ - def handleHostLost(host: String) { - if (!deadHosts.contains(host)) { + def handleHostLost(host: String, maybeGeneration: Option[Long] = None) { + val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration) + if (!failedGeneration.contains(host) || failedGeneration(host) < currentGeneration) { + failedGeneration(host) = currentGeneration logInfo("Host lost: " + host) - deadHosts += host env.blockManager.master.notifyADeadHost(host) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { @@ -519,6 +529,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray mapOutputTracker.registerMapOutputs(shuffleId, locs, true) } + if (shuffleToMapStage.isEmpty) { + mapOutputTracker.incrementGeneration() + } cacheTracker.cacheLost(host) updateCacheLocs() } -- cgit v1.2.3 From 7e9ee2e8335f085062d3fdeecd0b49ec63e92117 Mon Sep 17 00:00:00 2001 From: Leemoonsoo Date: Tue, 22 Jan 2013 23:08:34 +0900 Subject: Fix for hanging spark.HttpFileServer with kind of virtual network --- core/src/main/scala/spark/HttpServer.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/HttpServer.scala b/core/src/main/scala/spark/HttpServer.scala index 0196595ba1..4e0507c080 100644 --- a/core/src/main/scala/spark/HttpServer.scala +++ b/core/src/main/scala/spark/HttpServer.scala @@ -4,6 +4,7 @@ import java.io.File import java.net.InetAddress import org.eclipse.jetty.server.Server +import org.eclipse.jetty.server.bio.SocketConnector import org.eclipse.jetty.server.handler.DefaultHandler import org.eclipse.jetty.server.handler.HandlerList import org.eclipse.jetty.server.handler.ResourceHandler @@ -27,7 +28,13 @@ private[spark] class HttpServer(resourceBase: File) extends Logging { if (server != null) { throw new ServerStateException("Server is already started") } else { - server = new Server(0) + server = new Server() + val connector = new SocketConnector + connector.setMaxIdleTime(60*1000) + connector.setSoLingerTime(-1) + connector.setPort(0) + server.addConnector(connector) + val threadPool = new QueuedThreadPool threadPool.setDaemon(true) server.setThreadPool(threadPool) -- cgit v1.2.3 From 588b24197a85c4b46a38595007293abef9a41f2c Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 10:19:30 -0600 Subject: Use default arguments instead of constructor overloads. --- core/src/main/scala/spark/SparkContext.scala | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 8b6f4b3b7d..495d1b6c78 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -58,27 +58,11 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend class SparkContext( val master: String, val jobName: String, - val sparkHome: String, - val jars: Seq[String], - environment: Map[String, String]) + val sparkHome: String = null, + val jars: Seq[String] = Nil, + environment: Map[String, String] = Map()) extends Logging { - /** - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param jobName A name for your job, to display on the cluster web UI - * @param sparkHome Location where Spark is installed on cluster nodes. - * @param jars Collection of JARs to send to the cluster. These can be paths on the local file - * system or HDFS, HTTP, HTTPS, or FTP URLs. - */ - def this(master: String, jobName: String, sparkHome: String, jars: Seq[String]) = - this(master, jobName, sparkHome, jars, Map()) - - /** - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param jobName A name for your job, to display on the cluster web UI - */ - def this(master: String, jobName: String) = this(master, jobName, null, Nil, Map()) - // Ensure logging is initialized before we spawn any threads initLogging() -- cgit v1.2.3 From 50e2b23927956c14db40093d31bc80892764006a Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 22 Jan 2013 09:27:33 -0800 Subject: Fix up some problems from the merge --- .../main/scala/spark/storage/BlockManagerMasterActor.scala | 11 +++++++++++ core/src/main/scala/spark/storage/BlockManagerMessages.scala | 3 +++ core/src/main/scala/spark/storage/StorageUtils.scala | 8 ++++---- 3 files changed, 18 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index f4d026da33..c945c34c71 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -68,6 +68,9 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { case GetMemoryStatus => getMemoryStatus + case GetStorageStatus => + getStorageStatus + case RemoveBlock(blockId) => removeBlock(blockId) @@ -177,6 +180,14 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { sender ! res } + private def getStorageStatus() { + val res = blockManagerInfo.map { case(blockManagerId, info) => + import collection.JavaConverters._ + StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap) + } + sender ! res + } + private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { val startTimeMs = System.currentTimeMillis() val tmp = " " + blockManagerId + " " diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index d73a9b790f..3a381fd385 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -100,3 +100,6 @@ case object GetMemoryStatus extends ToBlockManagerMaster private[spark] case object ExpireDeadHosts extends ToBlockManagerMaster + +private[spark] +case object GetStorageStatus extends ToBlockManagerMaster \ No newline at end of file diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index ebc7390ee5..63ad5c125b 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -1,6 +1,7 @@ package spark.storage import spark.SparkContext +import BlockManagerMasterActor.BlockStatus private[spark] case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, @@ -20,8 +21,8 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, } -case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, - numPartitions: Int, memSize: Long, diskSize: Long, locations: Array[BlockManagerId]) +case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, + numPartitions: Int, memSize: Long, diskSize: Long) /* Helper methods for storage-related objects */ @@ -58,8 +59,7 @@ object StorageUtils { val rddName = Option(sc.persistentRdds.get(rddId).name).getOrElse(rddKey) val rddStorageLevel = sc.persistentRdds.get(rddId).getStorageLevel - RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize, - rddBlocks.map(_.blockManagerId)) + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize) }.toArray } -- cgit v1.2.3 From 27b3f3f0a980f86bac14a14516b5d52a32aa8cbb Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 15:30:42 -0600 Subject: Handle slaveLost before slaveIdToHost knows about it. --- .../spark/scheduler/cluster/ClusterScheduler.scala | 31 +++++++++++++--------- 1 file changed, 18 insertions(+), 13 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 20f6e65020..a639b72795 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -252,19 +252,24 @@ private[spark] class ClusterScheduler(val sc: SparkContext) def slaveLost(slaveId: String, reason: ExecutorLossReason) { var failedHost: Option[String] = None synchronized { - val host = slaveIdToHost(slaveId) - if (hostsAlive.contains(host)) { - logError("Lost an executor on " + host + ": " + reason) - slaveIdsWithExecutors -= slaveId - hostsAlive -= host - activeTaskSetsQueue.foreach(_.hostLost(host)) - failedHost = Some(host) - } else { - // We may get multiple slaveLost() calls with different loss reasons. For example, one - // may be triggered by a dropped connection from the slave while another may be a report - // of executor termination from Mesos. We produce log messages for both so we eventually - // report the termination reason. - logError("Lost an executor on " + host + " (already removed): " + reason) + slaveIdToHost.get(slaveId) match { + case Some(host) => + if (hostsAlive.contains(host)) { + logError("Lost an executor on " + host + ": " + reason) + slaveIdsWithExecutors -= slaveId + hostsAlive -= host + activeTaskSetsQueue.foreach(_.hostLost(host)) + failedHost = Some(host) + } else { + // We may get multiple slaveLost() calls with different loss reasons. For example, one + // may be triggered by a dropped connection from the slave while another may be a report + // of executor termination from Mesos. We produce log messages for both so we eventually + // report the termination reason. + logError("Lost an executor on " + host + " (already removed): " + reason) + } + case None => + // We were told about a slave being lost before we could even allocate work to it + logError("Lost slave " + slaveId + " (no work assigned yet)") } } if (failedHost != None) { -- cgit v1.2.3 From 6f2194f7576eb188c23f18125f5101ae0b4e9e4d Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 15:38:58 -0600 Subject: Call removeJob instead of killing the cluster. --- core/src/main/scala/spark/deploy/master/Master.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 2c2cd0231b..d1a65204b8 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -103,8 +103,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor val e = new SparkException("Job %s wth ID %s failed %d times.".format( jobInfo.desc.name, jobInfo.id, jobInfo.retryCount)) logError(e.getMessage, e) - throw e - //System.exit(1) + removeJob(jobInfo) } } } -- cgit v1.2.3 From 250fe89679bb59ef0d31f74985f72556dcfe2d06 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 16:29:05 -0600 Subject: Handle Master telling the Worker to kill an already-dead executor. --- core/src/main/scala/spark/deploy/worker/Worker.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 19bf2be118..d040b86908 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -143,9 +143,13 @@ private[spark] class Worker( case KillExecutor(jobId, execId) => val fullId = jobId + "/" + execId - val executor = executors(fullId) - logInfo("Asked to kill executor " + fullId) - executor.kill() + executors.get(fullId) match { + case Some(executor) => + logInfo("Asked to kill executor " + fullId) + executor.kill() + case None => + logInfo("Asked to kill non-existent existent " + fullId) + } case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => masterDisconnected() -- cgit v1.2.3 From 2437f6741b9c5b0a778d55d324aabdc4642889e5 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 18:01:03 -0600 Subject: Restore SPARK_MEM in executorEnvs. --- core/src/main/scala/spark/SparkContext.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index a5a1b75944..402355bd52 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -111,8 +111,9 @@ class SparkContext( // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() - // Note: SPARK_MEM isn't included because it's set directly in ExecutorRunner - for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) { + // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner + for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", + "SPARK_TESTING")) { val value = System.getenv(key) if (value != null) { executorEnvs(key) = value -- cgit v1.2.3 From fdec42385a1a8f10f9dd803525cb3c132a25ba53 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 18:01:12 -0600 Subject: Fix SPARK_MEM in ExecutorRunner. --- core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index 2f2ea617ff..e910416235 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -118,7 +118,7 @@ private[spark] class ExecutorRunner( for ((key, value) <- jobDesc.command.environment) { env.put(key, value) } - env.put("SPARK_MEM", memory.toString) + env.put("SPARK_MEM", memory.toString + "m") // In case we are running this from within the Spark Shell, avoid creating a "scala" // parent process for the executor command env.put("SPARK_LAUNCH_WITH_SCALA", "0") -- cgit v1.2.3 From 8c51322cd05f2ae97a08c3af314c7608fcf71b57 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 18:09:10 -0600 Subject: Don't bother creating an exception. --- core/src/main/scala/spark/deploy/master/Master.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index d1a65204b8..361e5ac627 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -100,9 +100,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor if (jobInfo.incrementRetryCount <= JobState.MAX_NUM_RETRY) { schedule() } else { - val e = new SparkException("Job %s wth ID %s failed %d times.".format( + logError("Job %s wth ID %s failed %d times, removing it".format( jobInfo.desc.name, jobInfo.id, jobInfo.retryCount)) - logError(e.getMessage, e) removeJob(jobInfo) } } -- cgit v1.2.3 From 98d0b7747d7539db009a9bbc261f899955871524 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 18:11:51 -0600 Subject: Fix Worker logInfo about unknown executor. --- core/src/main/scala/spark/deploy/worker/Worker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index d040b86908..5a83a42daf 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -148,7 +148,7 @@ private[spark] class Worker( logInfo("Asked to kill executor " + fullId) executor.kill() case None => - logInfo("Asked to kill non-existent existent " + fullId) + logInfo("Asked to kill unknown executor " + fullId) } case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => -- cgit v1.2.3 From 284993100022cc4bd43bf84a0be4dd91cf7a4ac0 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 22 Jan 2013 22:19:30 -0800 Subject: Eliminate CacheTracker. Replaces DAGScheduler's queries of CacheTracker with BlockManagerMaster queries. Adds CacheManager to locally coordinate computation of cached RDDs. --- core/src/main/scala/spark/CacheTracker.scala | 240 --------------------- core/src/main/scala/spark/RDD.scala | 2 +- core/src/main/scala/spark/SparkEnv.scala | 8 +- .../main/scala/spark/scheduler/DAGScheduler.scala | 24 ++- .../main/scala/spark/storage/BlockManager.scala | 24 +-- core/src/test/scala/spark/CacheTrackerSuite.scala | 131 ----------- 6 files changed, 18 insertions(+), 411 deletions(-) delete mode 100644 core/src/main/scala/spark/CacheTracker.scala delete mode 100644 core/src/test/scala/spark/CacheTrackerSuite.scala (limited to 'core') diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala deleted file mode 100644 index 86ad737583..0000000000 --- a/core/src/main/scala/spark/CacheTracker.scala +++ /dev/null @@ -1,240 +0,0 @@ -package spark - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -import akka.actor._ -import akka.dispatch._ -import akka.pattern.ask -import akka.remote._ -import akka.util.Duration -import akka.util.Timeout -import akka.util.duration._ - -import spark.storage.BlockManager -import spark.storage.StorageLevel -import util.{TimeStampedHashSet, MetadataCleaner, TimeStampedHashMap} - -private[spark] sealed trait CacheTrackerMessage - -private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L) - extends CacheTrackerMessage -private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L) - extends CacheTrackerMessage -private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage -private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage -private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage -private[spark] case object GetCacheStatus extends CacheTrackerMessage -private[spark] case object GetCacheLocations extends CacheTrackerMessage -private[spark] case object StopCacheTracker extends CacheTrackerMessage - -private[spark] class CacheTrackerActor extends Actor with Logging { - // TODO: Should probably store (String, CacheType) tuples - private val locs = new TimeStampedHashMap[Int, Array[List[String]]] - - /** - * A map from the slave's host name to its cache size. - */ - private val slaveCapacity = new HashMap[String, Long] - private val slaveUsage = new HashMap[String, Long] - - private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.clearOldValues) - - private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) - private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) - private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host) - - def receive = { - case SlaveCacheStarted(host: String, size: Long) => - slaveCapacity.put(host, size) - slaveUsage.put(host, 0) - sender ! true - - case RegisterRDD(rddId: Int, numPartitions: Int) => - logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions") - locs(rddId) = Array.fill[List[String]](numPartitions)(Nil) - sender ! true - - case AddedToCache(rddId, partition, host, size) => - slaveUsage.put(host, getCacheUsage(host) + size) - locs(rddId)(partition) = host :: locs(rddId)(partition) - sender ! true - - case DroppedFromCache(rddId, partition, host, size) => - slaveUsage.put(host, getCacheUsage(host) - size) - // Do a sanity check to make sure usage is greater than 0. - locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host) - sender ! true - - case MemoryCacheLost(host) => - logInfo("Memory cache lost on " + host) - for ((id, locations) <- locs) { - for (i <- 0 until locations.length) { - locations(i) = locations(i).filterNot(_ == host) - } - } - sender ! true - - case GetCacheLocations => - logInfo("Asked for current cache locations") - sender ! locs.map{case (rrdId, array) => (rrdId -> array.clone())} - - case GetCacheStatus => - val status = slaveCapacity.map { case (host, capacity) => - (host, capacity, getCacheUsage(host)) - }.toSeq - sender ! status - - case StopCacheTracker => - logInfo("Stopping CacheTrackerActor") - sender ! true - metadataCleaner.cancel() - context.stop(self) - } -} - -private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager) - extends Logging { - - // Tracker actor on the master, or remote reference to it on workers - val ip: String = System.getProperty("spark.master.host", "localhost") - val port: Int = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "CacheTracker" - - val timeout = 10.seconds - - var trackerActor: ActorRef = if (isMaster) { - val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName) - logInfo("Registered CacheTrackerActor actor") - actor - } else { - val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) - actorSystem.actorFor(url) - } - - // TODO: Consider removing this HashSet completely as locs CacheTrackerActor already - // keeps track of registered RDDs - val registeredRddIds = new TimeStampedHashSet[Int] - - // Remembers which splits are currently being loaded (on worker nodes) - val loading = new HashSet[String] - - val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.clearOldValues) - - // Send a message to the trackerActor and get its result within a default timeout, or - // throw a SparkException if this fails. - def askTracker(message: Any): Any = { - try { - val future = trackerActor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with CacheTracker", e) - } - } - - // Send a one-way message to the trackerActor, to which we expect it to reply with true. - def communicate(message: Any) { - if (askTracker(message) != true) { - throw new SparkException("Error reply received from CacheTracker") - } - } - - // Registers an RDD (on master only) - def registerRDD(rddId: Int, numPartitions: Int) { - registeredRddIds.synchronized { - if (!registeredRddIds.contains(rddId)) { - logInfo("Registering RDD ID " + rddId + " with cache") - registeredRddIds += rddId - communicate(RegisterRDD(rddId, numPartitions)) - } - } - } - - // For BlockManager.scala only - def cacheLost(host: String) { - communicate(MemoryCacheLost(host)) - logInfo("CacheTracker successfully removed entries on " + host) - } - - // Get the usage status of slave caches. Each tuple in the returned sequence - // is in the form of (host name, capacity, usage). - def getCacheStatus(): Seq[(String, Long, Long)] = { - askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]] - } - - // For BlockManager.scala only - def notifyFromBlockManager(t: AddedToCache) { - communicate(t) - } - - // Get a snapshot of the currently known locations - def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = { - askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]] - } - - // Gets or computes an RDD split - def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) - : Iterator[T] = { - val key = "rdd_%d_%d".format(rdd.id, split.index) - logInfo("Cache key is " + key) - blockManager.get(key) match { - case Some(cachedValues) => - // Split is in cache, so just return its values - logInfo("Found partition in cache!") - return cachedValues.asInstanceOf[Iterator[T]] - - case None => - // Mark the split as loading (unless someone else marks it first) - loading.synchronized { - if (loading.contains(key)) { - logInfo("Loading contains " + key + ", waiting...") - while (loading.contains(key)) { - try {loading.wait()} catch {case _ =>} - } - logInfo("Loading no longer contains " + key + ", so returning cached result") - // See whether someone else has successfully loaded it. The main way this would fail - // is for the RDD-level cache eviction policy if someone else has loaded the same RDD - // partition but we didn't want to make space for it. However, that case is unlikely - // because it's unlikely that two threads would work on the same RDD partition. One - // downside of the current code is that threads wait serially if this does happen. - blockManager.get(key) match { - case Some(values) => - return values.asInstanceOf[Iterator[T]] - case None => - logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") - loading.add(key) - } - } else { - loading.add(key) - } - } - try { - // If we got here, we have to load the split - val elements = new ArrayBuffer[Any] - logInfo("Computing partition " + split) - elements ++= rdd.compute(split, context) - // Try to put this block in the blockManager - blockManager.put(key, elements, storageLevel, true) - return elements.iterator.asInstanceOf[Iterator[T]] - } finally { - loading.synchronized { - loading.remove(key) - loading.notifyAll() - } - } - } - } - - // Called by the Cache to report that an entry has been dropped from it - def dropEntry(rddId: Int, partition: Int) { - communicate(DroppedFromCache(rddId, partition, Utils.localHostName())) - } - - def stop() { - communicate(StopCacheTracker) - registeredRddIds.clear() - trackerActor = null - } -} diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e0d2eabb1d..c79f34342f 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -176,7 +176,7 @@ abstract class RDD[T: ClassManifest]( if (isCheckpointed) { checkpointData.get.iterator(split, context) } else if (storageLevel != StorageLevel.NONE) { - SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel) + SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) } else { compute(split, context) } diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 41441720a7..a080194980 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -22,7 +22,7 @@ class SparkEnv ( val actorSystem: ActorSystem, val serializer: Serializer, val closureSerializer: Serializer, - val cacheTracker: CacheTracker, + val cacheManager: CacheManager, val mapOutputTracker: MapOutputTracker, val shuffleFetcher: ShuffleFetcher, val broadcastManager: BroadcastManager, @@ -39,7 +39,6 @@ class SparkEnv ( def stop() { httpFileServer.stop() mapOutputTracker.stop() - cacheTracker.stop() shuffleFetcher.stop() broadcastManager.stop() blockManager.stop() @@ -100,8 +99,7 @@ object SparkEnv extends Logging { val closureSerializer = instantiateClass[Serializer]( "spark.closure.serializer", "spark.JavaSerializer") - val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager) - blockManager.cacheTracker = cacheTracker + val cacheManager = new CacheManager(blockManager) val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster) @@ -122,7 +120,7 @@ object SparkEnv extends Logging { actorSystem, serializer, closureSerializer, - cacheTracker, + cacheManager, mapOutputTracker, shuffleFetcher, broadcastManager, diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 59f2099e91..03d173ac3b 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -69,8 +69,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with var cacheLocs = new HashMap[Int, Array[List[String]]] val env = SparkEnv.get - val cacheTracker = env.cacheTracker val mapOutputTracker = env.mapOutputTracker + val blockManagerMaster = env.blockManager.master val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back; // that's not going to be a realistic assumption in general @@ -95,11 +95,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with }.start() def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { + if (!cacheLocs.contains(rdd.id)) { + val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray + cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map { + locations => locations.map(_.ip).toList + }.toArray + } cacheLocs(rdd.id) } - def updateCacheLocs() { - cacheLocs = cacheTracker.getLocationsSnapshot() + def clearCacheLocs() { + cacheLocs.clear } /** @@ -126,7 +132,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of splits is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") - cacheTracker.registerRDD(rdd.id, rdd.splits.size) if (shuffleDep != None) { mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) } @@ -148,8 +153,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with visited += r // Kind of ugly: need to register RDDs with the cache here since // we can't do it in its constructor because # of splits is unknown - logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")") - cacheTracker.registerRDD(r.id, r.splits.size) for (dep <- r.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => @@ -250,7 +253,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val runId = nextRunId.getAndIncrement() val finalStage = newStage(finalRDD, None, runId) val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener) - updateCacheLocs() + clearCacheLocs() logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + " output partitions") logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") @@ -293,7 +296,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // on the failed node. if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { logInfo("Resubmitting failed stages") - updateCacheLocs() + clearCacheLocs() val failed2 = failed.toArray failed.clear() for (stage <- failed2.sortBy(_.priority)) { @@ -443,7 +446,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with stage.shuffleDep.get.shuffleId, stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray) } - updateCacheLocs() + clearCacheLocs() if (stage.outputLocs.count(_ == Nil) != 0) { // Some tasks had failed; let's resubmit this stage // TODO: Lower-level scheduler should also deal with this @@ -519,8 +522,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray mapOutputTracker.registerMapOutputs(shuffleId, locs, true) } - cacheTracker.cacheLost(host) - updateCacheLocs() + clearCacheLocs() } } diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 7a8ac10cdd..e049565f48 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -16,7 +16,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils} +import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils} import spark.network._ import spark.serializer.Serializer import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap} @@ -71,9 +71,6 @@ class BlockManager( val connectionManagerId = connectionManager.id val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port) - // TODO: This will be removed after cacheTracker is removed from the code base. - var cacheTracker: CacheTracker = null - // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) val maxBytesInFlight = @@ -662,10 +659,6 @@ class BlockManager( BlockManager.dispose(bytesAfterPut) - // TODO: This code will be removed when CacheTracker is gone. - if (blockId.startsWith("rdd")) { - notifyCacheTracker(blockId) - } logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)) return size @@ -733,11 +726,6 @@ class BlockManager( } } - // TODO: This code will be removed when CacheTracker is gone. - if (blockId.startsWith("rdd")) { - notifyCacheTracker(blockId) - } - // If replication had started, then wait for it to finish if (level.replication > 1) { if (replicationFuture == null) { @@ -780,16 +768,6 @@ class BlockManager( } } - // TODO: This code will be removed when CacheTracker is gone. - private def notifyCacheTracker(key: String) { - if (cacheTracker != null) { - val rddInfo = key.split("_") - val rddId: Int = rddInfo(1).toInt - val partition: Int = rddInfo(2).toInt - cacheTracker.notifyFromBlockManager(spark.AddedToCache(rddId, partition, host)) - } - } - /** * Read a block consisting of a single object. */ diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala deleted file mode 100644 index 467605981b..0000000000 --- a/core/src/test/scala/spark/CacheTrackerSuite.scala +++ /dev/null @@ -1,131 +0,0 @@ -package spark - -import org.scalatest.FunSuite - -import scala.collection.mutable.HashMap - -import akka.actor._ -import akka.dispatch._ -import akka.pattern.ask -import akka.remote._ -import akka.util.Duration -import akka.util.Timeout -import akka.util.duration._ - -class CacheTrackerSuite extends FunSuite { - // Send a message to an actor and wait for a reply, in a blocking manner - private def ask(actor: ActorRef, message: Any): Any = { - try { - val timeout = 10.seconds - val future = actor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with actor", e) - } - } - - test("CacheTrackerActor slave initialization & cache status") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 0L))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - test("RegisterRDD") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, RegisterRDD(1, 3)) === true) - assert(ask(tracker, RegisterRDD(2, 1)) === true) - - assert(getCacheLocations(tracker) === Map(1 -> List(Nil, Nil, Nil), 2 -> List(Nil))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - test("AddedToCache") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, RegisterRDD(1, 2)) === true) - assert(ask(tracker, RegisterRDD(2, 1)) === true) - - assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true) - assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true) - assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L))) - - assert(getCacheLocations(tracker) === - Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - test("DroppedFromCache") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, RegisterRDD(1, 2)) === true) - assert(ask(tracker, RegisterRDD(2, 1)) === true) - - assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true) - assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true) - assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L))) - assert(getCacheLocations(tracker) === - Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) - - assert(ask(tracker, DroppedFromCache(1, 1, "host001", 2L << 11)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 68608L))) - assert(getCacheLocations(tracker) === - Map(1 -> List(List("host001"),List()), 2 -> List(List("host001")))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - /** - * Helper function to get cacheLocations from CacheTracker - */ - def getCacheLocations(tracker: ActorRef): HashMap[Int, List[List[String]]] = { - val answer = ask(tracker, GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]] - answer.map { case (i, arr) => (i, arr.toList) } - } -} -- cgit v1.2.3 From 43e9ff959645e533bcfa0a5c31e62e32c7e9d0a6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 22 Jan 2013 22:47:26 -0800 Subject: Add test for driver hanging on exit (SPARK-530). --- core/src/test/scala/spark/DriverSuite.scala | 31 +++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 core/src/test/scala/spark/DriverSuite.scala (limited to 'core') diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala new file mode 100644 index 0000000000..70a7c8bc2f --- /dev/null +++ b/core/src/test/scala/spark/DriverSuite.scala @@ -0,0 +1,31 @@ +package spark + +import java.io.File + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts +import org.scalatest.prop.TableDrivenPropertyChecks._ +import org.scalatest.time.SpanSugar._ + +class DriverSuite extends FunSuite with Timeouts { + test("driver should exit after finishing") { + // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" + val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) + forAll(masters) { (master: String) => + failAfter(10 seconds) { + Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master), new File(System.getenv("SPARK_HOME"))) + } + } + } +} + +/** + * Program that creates a Spark driver but doesn't call SparkContext.stop() or + * Sys.exit() after finishing. + */ +object DriverWithoutCleanup { + def main(args: Array[String]) { + val sc = new SparkContext(args(0), "DriverWithoutCleanup") + sc.parallelize(1 to 100, 4).count() + } +} \ No newline at end of file -- cgit v1.2.3 From bacade6caf7527737dc6f02b1c2ca9114e02d8bc Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Jan 2013 22:55:26 -0800 Subject: Modified BlockManagerId API to ensure zero duplicate objects. Fixed BlockManagerId testcase in BlockManagerTestSuite. --- .../src/main/scala/spark/scheduler/MapStatus.scala | 2 +- .../main/scala/spark/storage/BlockManager.scala | 2 +- .../main/scala/spark/storage/BlockManagerId.scala | 33 ++++++++++++++++++---- .../scala/spark/storage/BlockManagerMessages.scala | 3 +- .../test/scala/spark/MapOutputTrackerSuite.scala | 22 +++++++-------- .../scala/spark/storage/BlockManagerSuite.scala | 18 ++++++------ 6 files changed, 51 insertions(+), 29 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala index 4532d9497f..fae643f3a8 100644 --- a/core/src/main/scala/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/spark/scheduler/MapStatus.scala @@ -20,7 +20,7 @@ private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes: } def readExternal(in: ObjectInput) { - address = new BlockManagerId(in) + address = BlockManagerId(in) compressedSizes = new Array[Byte](in.readInt()) in.readFully(compressedSizes) } diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 7a8ac10cdd..596a69c583 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -69,7 +69,7 @@ class BlockManager( implicit val futureExecContext = connectionManager.futureExecContext val connectionManagerId = connectionManager.id - val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port) + val blockManagerId = BlockManagerId(connectionManagerId.host, connectionManagerId.port) // TODO: This will be removed after cacheTracker is removed from the code base. var cacheTracker: CacheTracker = null diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index 488679f049..26c98f2ac8 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -3,20 +3,35 @@ package spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.ConcurrentHashMap +/** + * This class represent an unique identifier for a BlockManager. + * The first 2 constructors of this class is made private to ensure that + * BlockManagerId objects can be created only using the factory method in + * [[spark.storage.BlockManager$]]. This allows de-duplication of id objects. + * Also, constructor parameters are private to ensure that parameters cannot + * be modified from outside this class. + */ +private[spark] class BlockManagerId private ( + private var ip_ : String, + private var port_ : Int + ) extends Externalizable { + + private def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) -private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { def this() = this(null, 0) // For deserialization only - def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) + def ip = ip_ + + def port = port_ override def writeExternal(out: ObjectOutput) { - out.writeUTF(ip) - out.writeInt(port) + out.writeUTF(ip_) + out.writeInt(port_) } override def readExternal(in: ObjectInput) { - ip = in.readUTF() - port = in.readInt() + ip_ = in.readUTF() + port_ = in.readInt() } @throws(classOf[IOException]) @@ -35,6 +50,12 @@ private[spark] class BlockManagerId(var ip: String, var port: Int) extends Exter private[spark] object BlockManagerId { + def apply(ip: String, port: Int) = + getCachedBlockManagerId(new BlockManagerId(ip, port)) + + def apply(in: ObjectInput) = + getCachedBlockManagerId(new BlockManagerId(in)) + val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index d73a9b790f..7437fc63eb 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -54,8 +54,7 @@ class UpdateBlockInfo( } override def readExternal(in: ObjectInput) { - blockManagerId = new BlockManagerId() - blockManagerId.readExternal(in) + blockManagerId = BlockManagerId(in) blockId = in.readUTF() storageLevel = new StorageLevel() storageLevel.readExternal(in) diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index d3dd3a8fa4..095f415978 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -47,13 +47,13 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000), Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000), Array(compressedSize10000, compressedSize1000))) val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === Seq((new BlockManagerId("hostA", 1000), size1000), - (new BlockManagerId("hostB", 1000), size10000))) + assert(statuses.toSeq === Seq((BlockManagerId("hostA", 1000), size1000), + (BlockManagerId("hostB", 1000), size10000))) tracker.stop() } @@ -65,14 +65,14 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000), Array(compressedSize1000, compressedSize1000, compressedSize1000))) - tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000), Array(compressedSize10000, compressedSize1000, compressedSize1000))) // As if we had two simulatenous fetch failures - tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) - tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) // The remaining reduce task might try to grab the output dispite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the @@ -95,13 +95,13 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { val compressedSize1000 = MapOutputTracker.compressSize(1000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) masterTracker.registerMapOutput(10, 0, new MapStatus( - new BlockManagerId("hostA", 1000), Array(compressedSize1000))) + BlockManagerId("hostA", 1000), Array(compressedSize1000))) masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((new BlockManagerId("hostA", 1000), size1000))) + Seq((BlockManagerId("hostA", 1000), size1000))) - masterTracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) + masterTracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 8f86e3170e..a33d3324ba 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -82,16 +82,18 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("BlockManagerId object caching") { - val id1 = new StorageLevel(false, false, false, 3) - val id2 = new StorageLevel(false, false, false, 3) + val id1 = BlockManagerId("XXX", 1) + val id2 = BlockManagerId("XXX", 1) // this should return the same object as id1 + assert(id2 === id1, "id2 is not same as id1") + assert(id2.eq(id1), "id2 is not the same object as id1") val bytes1 = spark.Utils.serialize(id1) - val id1_ = spark.Utils.deserialize[StorageLevel](bytes1) + val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1) val bytes2 = spark.Utils.serialize(id2) - val id2_ = spark.Utils.deserialize[StorageLevel](bytes2) - assert(id1_ === id1, "Deserialized id1 not same as original id1") - assert(id2_ === id2, "Deserialized id2 not same as original id1") - assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2") - assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized level1") + val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2) + assert(id1_ === id1, "Deserialized id1 is not same as original id1") + assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1") + assert(id2_ === id2, "Deserialized id2 is not same as original id2") + assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1") } test("master + 1 manager interaction") { -- cgit v1.2.3 From 5e11f1e51f17113abb8d3a5bc261af5ba5ffce94 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Jan 2013 23:42:53 -0800 Subject: Modified StorageLevel API to ensure zero duplicate objects. --- .../main/scala/spark/storage/BlockManager.scala | 5 +-- .../main/scala/spark/storage/BlockMessage.scala | 2 +- .../main/scala/spark/storage/StorageLevel.scala | 47 ++++++++++++++-------- .../scala/spark/storage/BlockManagerSuite.scala | 16 +++++--- 4 files changed, 44 insertions(+), 26 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 596a69c583..ca7eb13ec8 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -191,7 +191,7 @@ class BlockManager( case level => val inMem = level.useMemory && memoryStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) - val storageLevel = new StorageLevel(onDisk, inMem, level.deserialized, level.replication) + val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication) val memSize = if (inMem) memoryStore.getSize(blockId) else 0L val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L (storageLevel, memSize, diskSize, info.tellMaster) @@ -760,8 +760,7 @@ class BlockManager( */ var cachedPeers: Seq[BlockManagerId] = null private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { - val tLevel: StorageLevel = - new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) + val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) if (cachedPeers == null) { cachedPeers = master.getPeers(blockManagerId, level.replication - 1) } diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala index 3f234df654..30d7500e01 100644 --- a/core/src/main/scala/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/spark/storage/BlockMessage.scala @@ -64,7 +64,7 @@ private[spark] class BlockMessage() { val booleanInt = buffer.getInt() val replication = buffer.getInt() - level = new StorageLevel(booleanInt, replication) + level = StorageLevel(booleanInt, replication) val dataLength = buffer.getInt() data = ByteBuffer.allocate(dataLength) diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index e3544e5aae..f2535ae5ae 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -7,25 +7,30 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory * in a serialized format, and whether to replicate the RDD partitions on multiple nodes. * The [[spark.storage.StorageLevel$]] singleton object contains some static constants for - * commonly useful storage levels. + * commonly useful storage levels. The recommended method to create your own storage level + * object is to use `StorageLevel.apply(...)` from the singleton object. */ class StorageLevel( - var useDisk: Boolean, - var useMemory: Boolean, - var deserialized: Boolean, - var replication: Int = 1) + private var useDisk_ : Boolean, + private var useMemory_ : Boolean, + private var deserialized_ : Boolean, + private var replication_ : Int = 1) extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. - - assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") - - def this(flags: Int, replication: Int) { + private def this(flags: Int, replication: Int) { this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } def this() = this(false, true, false) // For deserialization + def useDisk = useDisk_ + def useMemory = useMemory_ + def deserialized = deserialized_ + def replication = replication_ + + assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") + override def clone(): StorageLevel = new StorageLevel( this.useDisk, this.useMemory, this.deserialized, this.replication) @@ -43,13 +48,13 @@ class StorageLevel( def toInt: Int = { var ret = 0 - if (useDisk) { + if (useDisk_) { ret |= 4 } - if (useMemory) { + if (useMemory_) { ret |= 2 } - if (deserialized) { + if (deserialized_) { ret |= 1 } return ret @@ -57,15 +62,15 @@ class StorageLevel( override def writeExternal(out: ObjectOutput) { out.writeByte(toInt) - out.writeByte(replication) + out.writeByte(replication_) } override def readExternal(in: ObjectInput) { val flags = in.readByte() - useDisk = (flags & 4) != 0 - useMemory = (flags & 2) != 0 - deserialized = (flags & 1) != 0 - replication = in.readByte() + useDisk_ = (flags & 4) != 0 + useMemory_ = (flags & 2) != 0 + deserialized_ = (flags & 1) != 0 + replication_ = in.readByte() } @throws(classOf[IOException]) @@ -91,6 +96,14 @@ object StorageLevel { val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) + /** Create a new StorageLevel object */ + def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) = + getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication)) + + /** Create a new StorageLevel object from its integer representation */ + def apply(flags: Int, replication: Int) = + getCachedStorageLevel(new StorageLevel(flags, replication)) + private[spark] val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index a33d3324ba..a1aeb12f25 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -69,23 +69,29 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("StorageLevel object caching") { - val level1 = new StorageLevel(false, false, false, 3) - val level2 = new StorageLevel(false, false, false, 3) + val level1 = StorageLevel(false, false, false, 3) + val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1 + val level3 = StorageLevel(false, false, false, 2) // this should return a different object + assert(level2 === level1, "level2 is not same as level1") + assert(level2.eq(level1), "level2 is not the same object as level1") + assert(level3 != level1, "level3 is same as level1") val bytes1 = spark.Utils.serialize(level1) val level1_ = spark.Utils.deserialize[StorageLevel](bytes1) val bytes2 = spark.Utils.serialize(level2) val level2_ = spark.Utils.deserialize[StorageLevel](bytes2) assert(level1_ === level1, "Deserialized level1 not same as original level1") - assert(level2_ === level2, "Deserialized level2 not same as original level1") - assert(level1_ === level2_, "Deserialized level1 not same as deserialized level2") - assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1") + assert(level1_.eq(level1), "Deserialized level1 not the same object as original level2") + assert(level2_ === level2, "Deserialized level2 not same as original level2") + assert(level2_.eq(level1), "Deserialized level2 not the same object as original level1") } test("BlockManagerId object caching") { val id1 = BlockManagerId("XXX", 1) val id2 = BlockManagerId("XXX", 1) // this should return the same object as id1 + val id3 = BlockManagerId("XXX", 2) // this should return a different object assert(id2 === id1, "id2 is not same as id1") assert(id2.eq(id1), "id2 is not the same object as id1") + assert(id3 != id1, "id3 is same as id1") val bytes1 = spark.Utils.serialize(id1) val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1) val bytes2 = spark.Utils.serialize(id2) -- cgit v1.2.3 From 155f31398dc83ecb88b4b3e07849a2a8a0a6592f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 23 Jan 2013 01:10:26 -0800 Subject: Made StorageLevel constructor private, and added StorageLevels.create() to the Java API. Updates scala and java programming guides. --- core/src/main/scala/spark/api/java/StorageLevels.java | 11 +++++++++++ core/src/main/scala/spark/storage/StorageLevel.scala | 6 +++--- docs/java-programming-guide.md | 3 ++- docs/scala-programming-guide.md | 3 ++- 4 files changed, 18 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/java/StorageLevels.java b/core/src/main/scala/spark/api/java/StorageLevels.java index 722af3c06c..5e5845ac3a 100644 --- a/core/src/main/scala/spark/api/java/StorageLevels.java +++ b/core/src/main/scala/spark/api/java/StorageLevels.java @@ -17,4 +17,15 @@ public class StorageLevels { public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2); public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1); public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2); + + /** + * Create a new StorageLevel object. + * @param useDisk saved to disk, if true + * @param useMemory saved to memory, if true + * @param deserialized saved as deserialized objects, if true + * @param replication replication factor + */ + public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) { + return StorageLevel.apply(useDisk, useMemory, deserialized, replication); + } } diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index f2535ae5ae..45d6ea2656 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -7,10 +7,10 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory * in a serialized format, and whether to replicate the RDD partitions on multiple nodes. * The [[spark.storage.StorageLevel$]] singleton object contains some static constants for - * commonly useful storage levels. The recommended method to create your own storage level - * object is to use `StorageLevel.apply(...)` from the singleton object. + * commonly useful storage levels. To create your own storage level object, use the factor method + * of the singleton object (`StorageLevel(...)`). */ -class StorageLevel( +class StorageLevel private( private var useDisk_ : Boolean, private var useMemory_ : Boolean, private var deserialized_ : Boolean, diff --git a/docs/java-programming-guide.md b/docs/java-programming-guide.md index 188ca4995e..37a906ea1c 100644 --- a/docs/java-programming-guide.md +++ b/docs/java-programming-guide.md @@ -75,7 +75,8 @@ class has a single abstract method, `call()`, that must be implemented. ## Storage Levels RDD [storage level](scala-programming-guide.html#rdd-persistence) constants, such as `MEMORY_AND_DISK`, are -declared in the [spark.api.java.StorageLevels](api/core/index.html#spark.api.java.StorageLevels) class. +declared in the [spark.api.java.StorageLevels](api/core/index.html#spark.api.java.StorageLevels) class. To +define your own storage level, you can use StorageLevels.create(...). # Other Features diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md index 7350eca837..301b330a79 100644 --- a/docs/scala-programming-guide.md +++ b/docs/scala-programming-guide.md @@ -301,7 +301,8 @@ We recommend going through the following process to select one: * Use the replicated storage levels if you want fast fault recovery (e.g. if using Spark to serve requests from a web application). *All* the storage levels provide full fault tolerance by recomputing lost data, but the replicated ones let you continue running tasks on the RDD without waiting to recompute a lost partition. - + +If you want to define your own storage level (say, with replication factor of 3 instead of 2), then use the function factor method `apply()` of the [`StorageLevel`](api/core/index.html#spark.storage.StorageLevel$) singleton object. # Shared Variables -- cgit v1.2.3 From 9a27062260490336a3bfa97c6efd39b1e7e81573 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 01:34:44 -0800 Subject: Force generation increment after shuffle map stage --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 39a1e6d6c6..d8a9049e81 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -445,9 +445,16 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with logInfo("waiting: " + waiting) logInfo("failed: " + failed) if (stage.shuffleDep != None) { + // We supply true to increment the generation number here in case this is a + // recomputation of the map outputs. In that case, some nodes may have cached + // locations with holes (from when we detected the error) and will need the + // generation incremented to refetch them. + // TODO: Only increment the generation number if this is not the first time + // we registered these map outputs. mapOutputTracker.registerMapOutputs( stage.shuffleDep.get.shuffleId, - stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray) + stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + true) } updateCacheLocs() if (stage.outputLocs.count(_ == Nil) != 0) { -- cgit v1.2.3 From d209b6b7641059610f734414ea05e0494b5510b0 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 01:35:14 -0800 Subject: Extra debugging from hostLost() --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index d8a9049e81..740aec2e61 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -528,7 +528,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration) if (!failedGeneration.contains(host) || failedGeneration(host) < currentGeneration) { failedGeneration(host) = currentGeneration - logInfo("Host lost: " + host) + logInfo("Host lost: " + host + " (generation " + currentGeneration + ")") env.blockManager.master.notifyADeadHost(host) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { @@ -541,6 +541,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } cacheTracker.cacheLost(host) updateCacheLocs() + } else { + logDebug("Additional host lost message for " + host + + "(generation " + currentGeneration + ")") } } -- cgit v1.2.3 From 0b506dd2ecec909cd514143389d0846db2d194ed Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 01:37:51 -0800 Subject: Add tests of various node failure scenarios. --- core/src/test/scala/spark/DistributedSuite.scala | 72 ++++++++++++++++++++++++ 1 file changed, 72 insertions(+) (limited to 'core') diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index cacc2796b6..0d6b265e54 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -188,4 +188,76 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter val values = sc.parallelize(1 to 2, 2).map(x => System.getenv("TEST_VAR")).collect() assert(values.toSeq === Seq("TEST_VALUE", "TEST_VALUE")) } + + test("recover from node failures") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(Seq(true, true), 2) + val singleton = sc.parallelize(Seq(true), 1) + assert(data.count === 2) // force executors to start + val masterId = SparkEnv.get.blockManager.blockManagerId + assert(data.map(markNodeIfIdentity).collect.size === 2) + assert(data.map(failOnMarkedIdentity).collect.size === 2) + } + + test("recover from repeated node failures during shuffle-map") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + for (i <- 1 to 3) { + val data = sc.parallelize(Seq(true, false), 2) + val singleton = sc.parallelize(Seq(false), 1) + assert(data.count === 2) + assert(data.map(markNodeIfIdentity).collect.size === 2) + assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2) + } + } + + test("recover from repeated node failures during shuffle-reduce") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + for (i <- 1 to 3) { + val data = sc.parallelize(Seq(true, true), 2) + val singleton = sc.parallelize(Seq(false), 1) + assert(data.count === 2) + assert(data.map(markNodeIfIdentity).collect.size === 2) + // This relies on mergeCombiners being used to perform the actual reduce for this + // test to actually be testing what it claims. + val grouped = data.map(x => x -> x).combineByKey( + x => x, + (x: Boolean, y: Boolean) => x, + (x: Boolean, y: Boolean) => failOnMarkedIdentity(x) + ) + assert(grouped.collect.size === 1) + } + } +} + +object DistributedSuite { + // Indicates whether this JVM is marked for failure. + var mark = false + + // Set by test to remember if we are in the driver program so we can assert + // that we are not. + var amMaster = false + + // Act like an identity function, but if the argument is true, set mark to true. + def markNodeIfIdentity(item: Boolean): Boolean = { + if (item) { + assert(!amMaster) + mark = true + } + item + } + + // Act like an identity function, but if mark was set to true previously, fail, + // crashing the entire JVM. + def failOnMarkedIdentity(item: Boolean): Boolean = { + if (mark) { + System.exit(42) + } + item + } } -- cgit v1.2.3 From 79d55700ce2559051ac61cc2fb72a67fd7035926 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 23 Jan 2013 01:57:09 -0800 Subject: One more fix. Made even default constructor of BlockManagerId private to prevent such problems in the future. --- core/src/main/scala/spark/storage/BlockManagerId.scala | 11 ++++++----- core/src/main/scala/spark/storage/BlockManagerMessages.scala | 3 +-- core/src/main/scala/spark/storage/StorageLevel.scala | 7 +++++++ 3 files changed, 14 insertions(+), 7 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index 26c98f2ac8..abb8b45a1f 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -16,9 +16,7 @@ private[spark] class BlockManagerId private ( private var port_ : Int ) extends Externalizable { - private def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) - - def this() = this(null, 0) // For deserialization only + private def this() = this(null, 0) // For deserialization only def ip = ip_ @@ -53,8 +51,11 @@ private[spark] object BlockManagerId { def apply(ip: String, port: Int) = getCachedBlockManagerId(new BlockManagerId(ip, port)) - def apply(in: ObjectInput) = - getCachedBlockManagerId(new BlockManagerId(in)) + def apply(in: ObjectInput) = { + val obj = new BlockManagerId() + obj.readExternal(in) + getCachedBlockManagerId(obj) + } val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index 7437fc63eb..30483b0b37 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -56,8 +56,7 @@ class UpdateBlockInfo( override def readExternal(in: ObjectInput) { blockManagerId = BlockManagerId(in) blockId = in.readUTF() - storageLevel = new StorageLevel() - storageLevel.readExternal(in) + storageLevel = StorageLevel(in) memSize = in.readInt() diskSize = in.readInt() } diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index 45d6ea2656..d1d1c61c1c 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -104,6 +104,13 @@ object StorageLevel { def apply(flags: Int, replication: Int) = getCachedStorageLevel(new StorageLevel(flags, replication)) + /** Read StorageLevel object from ObjectInput stream */ + def apply(in: ObjectInput) = { + val obj = new StorageLevel() + obj.readExternal(in) + getCachedStorageLevel(obj) + } + private[spark] val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() -- cgit v1.2.3 From ae2ed2947d43860c74a8d40767e289ca78073977 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 23 Jan 2013 10:36:18 -0800 Subject: Allow PySpark's SparkFiles to be used from driver Fix minor documentation formatting issues. --- core/src/main/scala/spark/SparkFiles.java | 8 ++++---- python/pyspark/context.py | 27 +++++++++++++++++++++------ python/pyspark/files.py | 20 +++++++++++++++++--- python/pyspark/tests.py | 23 +++++++++++++++++++++++ python/pyspark/worker.py | 1 + python/test_support/hello.txt | 1 + 6 files changed, 67 insertions(+), 13 deletions(-) create mode 100755 python/test_support/hello.txt (limited to 'core') diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java index b59d8ce93f..566aec622c 100644 --- a/core/src/main/scala/spark/SparkFiles.java +++ b/core/src/main/scala/spark/SparkFiles.java @@ -3,23 +3,23 @@ package spark; import java.io.File; /** - * Resolves paths to files added through `addFile(). + * Resolves paths to files added through `SparkContext.addFile()`. */ public class SparkFiles { private SparkFiles() {} /** - * Get the absolute path of a file added through `addFile()`. + * Get the absolute path of a file added through `SparkContext.addFile()`. */ public static String get(String filename) { return new File(getRootDirectory(), filename).getAbsolutePath(); } /** - * Get the root directory that contains files added through `addFile()`. + * Get the root directory that contains files added through `SparkContext.addFile()`. */ public static String getRootDirectory() { return SparkEnv.get().sparkFilesDir(); } -} \ No newline at end of file +} diff --git a/python/pyspark/context.py b/python/pyspark/context.py index b8d7dc05af..3e33776af0 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1,12 +1,15 @@ import os import atexit import shutil +import sys import tempfile +from threading import Lock from tempfile import NamedTemporaryFile from pyspark import accumulators from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast +from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import dump_pickle, write_with_length, batched from pyspark.rdd import RDD @@ -27,6 +30,8 @@ class SparkContext(object): _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile _takePartition = jvm.PythonRDD.takePartition _next_accum_id = 0 + _active_spark_context = None + _lock = Lock() def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): @@ -46,6 +51,11 @@ class SparkContext(object): Java object. Set 1 to disable batching or -1 to use an unlimited batch size. """ + with SparkContext._lock: + if SparkContext._active_spark_context: + raise ValueError("Cannot run multiple SparkContexts at once") + else: + SparkContext._active_spark_context = self self.master = master self.jobName = jobName self.sparkHome = sparkHome or None # None becomes null in Py4J @@ -75,6 +85,8 @@ class SparkContext(object): # Deploy any code dependencies specified in the constructor for path in (pyFiles or []): self.addPyFile(path) + SparkFiles._sc = self + sys.path.append(SparkFiles.getRootDirectory()) @property def defaultParallelism(self): @@ -85,17 +97,20 @@ class SparkContext(object): return self._jsc.sc().defaultParallelism() def __del__(self): - if self._jsc: - self._jsc.stop() - if self._accumulatorServer: - self._accumulatorServer.shutdown() + self.stop() def stop(self): """ Shut down the SparkContext. """ - self._jsc.stop() - self._jsc = None + if self._jsc: + self._jsc.stop() + self._jsc = None + if self._accumulatorServer: + self._accumulatorServer.shutdown() + self._accumulatorServer = None + with SparkContext._lock: + SparkContext._active_spark_context = None def parallelize(self, c, numSlices=None): """ diff --git a/python/pyspark/files.py b/python/pyspark/files.py index de1334f046..98f6a399cc 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/files.py @@ -4,13 +4,15 @@ import os class SparkFiles(object): """ Resolves paths to files added through - L{addFile()}. + L{SparkContext.addFile()}. SparkFiles contains only classmethods; users should not create SparkFiles instances. """ _root_directory = None + _is_running_on_worker = False + _sc = None def __init__(self): raise NotImplementedError("Do not construct SparkFiles objects") @@ -18,7 +20,19 @@ class SparkFiles(object): @classmethod def get(cls, filename): """ - Get the absolute path of a file added through C{addFile()}. + Get the absolute path of a file added through C{SparkContext.addFile()}. """ - path = os.path.join(SparkFiles._root_directory, filename) + path = os.path.join(SparkFiles.getRootDirectory(), filename) return os.path.abspath(path) + + @classmethod + def getRootDirectory(cls): + """ + Get the root directory that contains files added through + C{SparkContext.addFile()}. + """ + if cls._is_running_on_worker: + return cls._root_directory + else: + # This will have to change if we support multiple SparkContexts: + return cls._sc.jvm.spark.SparkFiles.getRootDirectory() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 4d70ee4f12..46ab34f063 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -4,22 +4,26 @@ individual modules. """ import os import shutil +import sys from tempfile import NamedTemporaryFile import time import unittest from pyspark.context import SparkContext +from pyspark.files import SparkFiles from pyspark.java_gateway import SPARK_HOME class PySparkTestCase(unittest.TestCase): def setUp(self): + self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ self.sc = SparkContext('local[4]', class_name , batchSize=2) def tearDown(self): self.sc.stop() + sys.path = self._old_sys_path # To avoid Akka rebinding to the same port, since it doesn't unbind # immediately on shutdown self.sc.jvm.System.clearProperty("spark.master.port") @@ -84,6 +88,25 @@ class TestAddFile(PySparkTestCase): res = self.sc.parallelize(range(2)).map(func).first() self.assertEqual("Hello World!", res) + def test_add_file_locally(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello.txt") + self.sc.addFile(path) + download_path = SparkFiles.get("hello.txt") + self.assertNotEqual(path, download_path) + with open(download_path) as test_file: + self.assertEquals("Hello World!\n", test_file.readline()) + + def test_add_py_file_locally(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this fails due to `userlibrary` not being on the Python path: + def func(): + from userlibrary import UserClass + self.assertRaises(ImportError, func) + path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") + self.sc.addFile(path) + from userlibrary import UserClass + self.assertEqual("Hello World!", UserClass().hello()) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4bf643da66..d33d6dd15f 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -26,6 +26,7 @@ def main(): split_index = read_int(sys.stdin) spark_files_dir = load_pickle(read_with_length(sys.stdin)) SparkFiles._root_directory = spark_files_dir + SparkFiles._is_running_on_worker = True sys.path.append(spark_files_dir) num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): diff --git a/python/test_support/hello.txt b/python/test_support/hello.txt new file mode 100755 index 0000000000..980a0d5f19 --- /dev/null +++ b/python/test_support/hello.txt @@ -0,0 +1 @@ +Hello World! -- cgit v1.2.3 From e1027ca6398fd5b1a99a2203df840911c4dccb27 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:22:11 -0800 Subject: Actually add CacheManager. --- core/src/main/scala/spark/CacheManager.scala | 65 ++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 core/src/main/scala/spark/CacheManager.scala (limited to 'core') diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala new file mode 100644 index 0000000000..a0b53fd9d6 --- /dev/null +++ b/core/src/main/scala/spark/CacheManager.scala @@ -0,0 +1,65 @@ +package spark + +import scala.collection.mutable.{ArrayBuffer, HashSet} +import spark.storage.{BlockManager, StorageLevel} + + +/** Spark class responsible for passing RDDs split contents to the BlockManager and making + sure a node doesn't load two copies of an RDD at once. + */ +private[spark] class CacheManager(blockManager: BlockManager) extends Logging { + private val loading = new HashSet[String] + + /** Gets or computes an RDD split. Used by RDD.iterator() when a RDD is cached. */ + def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) + : Iterator[T] = { + val key = "rdd_%d_%d".format(rdd.id, split.index) + logInfo("Cache key is " + key) + blockManager.get(key) match { + case Some(cachedValues) => + // Split is in cache, so just return its values + logInfo("Found partition in cache!") + return cachedValues.asInstanceOf[Iterator[T]] + + case None => + // Mark the split as loading (unless someone else marks it first) + loading.synchronized { + if (loading.contains(key)) { + logInfo("Loading contains " + key + ", waiting...") + while (loading.contains(key)) { + try {loading.wait()} catch {case _ =>} + } + logInfo("Loading no longer contains " + key + ", so returning cached result") + // See whether someone else has successfully loaded it. The main way this would fail + // is for the RDD-level cache eviction policy if someone else has loaded the same RDD + // partition but we didn't want to make space for it. However, that case is unlikely + // because it's unlikely that two threads would work on the same RDD partition. One + // downside of the current code is that threads wait serially if this does happen. + blockManager.get(key) match { + case Some(values) => + return values.asInstanceOf[Iterator[T]] + case None => + logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") + loading.add(key) + } + } else { + loading.add(key) + } + } + try { + // If we got here, we have to load the split + val elements = new ArrayBuffer[Any] + logInfo("Computing partition " + split) + elements ++= rdd.compute(split, context) + // Try to put this block in the blockManager + blockManager.put(key, elements, storageLevel, true) + return elements.iterator.asInstanceOf[Iterator[T]] + } finally { + loading.synchronized { + loading.remove(key) + loading.notifyAll() + } + } + } + } +} -- cgit v1.2.3 From 88b9d240fda7ca34c08752dfa66797eecb6db872 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:40:38 -0800 Subject: Remove dead code in test. --- core/src/test/scala/spark/DistributedSuite.scala | 2 -- 1 file changed, 2 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index 0d6b265e54..af66d33aa3 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -194,7 +194,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter DistributedSuite.amMaster = true sc = new SparkContext(clusterUrl, "test") val data = sc.parallelize(Seq(true, true), 2) - val singleton = sc.parallelize(Seq(true), 1) assert(data.count === 2) // force executors to start val masterId = SparkEnv.get.blockManager.blockManagerId assert(data.map(markNodeIfIdentity).collect.size === 2) @@ -207,7 +206,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter sc = new SparkContext(clusterUrl, "test") for (i <- 1 to 3) { val data = sc.parallelize(Seq(true, false), 2) - val singleton = sc.parallelize(Seq(false), 1) assert(data.count === 2) assert(data.map(markNodeIfIdentity).collect.size === 2) assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2) -- cgit v1.2.3 From be4a115a7ec7fb6ec0d34f1a1a1bb2c9bbe7600e Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:48:45 -0800 Subject: Clarify TODO. --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 740aec2e61..14a3ef8ad7 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -76,7 +76,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // sent with every task. When we detect a node failing, we note the current generation number // and failed host, increment it for new tasks, and use this to ignore stray ShuffleMapTask // results. - // TODO: Garbage collect information about failure generations when new stages start. + // TODO: Garbage collect information about failure generations when we know there are no more + // stray messages to detect. val failedGeneration = new HashMap[String, Long] val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done -- cgit v1.2.3 From e1985bfa04ad4583ac1f0f421cbe0182ce7c53df Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 21 Jan 2013 16:21:14 -0800 Subject: be sure to set class loader of kryo instances --- core/src/main/scala/spark/KryoSerializer.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 93d7327324..56919544e8 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -206,5 +206,8 @@ class KryoSerializer extends spark.serializer.Serializer with Logging { kryo } - def newInstance(): SerializerInstance = new KryoSerializerInstance(this) + def newInstance(): SerializerInstance = { + this.kryo.setClassLoader(Thread.currentThread().getContextClassLoader) + new KryoSerializerInstance(this) + } } -- cgit v1.2.3 From 5c7422292ecace947f78e5ebe97e83a355531af7 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:59:51 -0800 Subject: Remove more dead code from test. --- core/src/test/scala/spark/DistributedSuite.scala | 1 - 1 file changed, 1 deletion(-) (limited to 'core') diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index af66d33aa3..0487e06d12 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -218,7 +218,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter sc = new SparkContext(clusterUrl, "test") for (i <- 1 to 3) { val data = sc.parallelize(Seq(true, true), 2) - val singleton = sc.parallelize(Seq(false), 1) assert(data.count === 2) assert(data.map(markNodeIfIdentity).collect.size === 2) // This relies on mergeCombiners being used to perform the actual reduce for this -- cgit v1.2.3 From 1dd82743e09789f8fdae2f5628545c0cb9f79245 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Wed, 23 Jan 2013 13:07:27 -0800 Subject: Fix compile error due to cherry-pick --- core/src/main/scala/spark/KryoSerializer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 56919544e8..0bd73e936b 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -207,7 +207,7 @@ class KryoSerializer extends spark.serializer.Serializer with Logging { } def newInstance(): SerializerInstance = { - this.kryo.setClassLoader(Thread.currentThread().getContextClassLoader) + this.kryo.get().setClassLoader(Thread.currentThread().getContextClassLoader) new KryoSerializerInstance(this) } } -- cgit v1.2.3 From eb222b720647c9e92a867c591cc4914b9a6cb5c1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 15:29:02 -0800 Subject: Added pruntSplits method to RDD. --- core/src/main/scala/spark/RDD.scala | 10 +++++++++ .../main/scala/spark/rdd/SplitsPruningRDD.scala | 24 ++++++++++++++++++++++ core/src/test/scala/spark/RDDSuite.scala | 22 +++++++++++++------- 3 files changed, 49 insertions(+), 7 deletions(-) create mode 100644 core/src/main/scala/spark/rdd/SplitsPruningRDD.scala (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e0d2eabb1d..3d93ff33bb 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -40,6 +40,7 @@ import spark.rdd.MapPartitionsRDD import spark.rdd.MapPartitionsWithSplitRDD import spark.rdd.PipedRDD import spark.rdd.SampledRDD +import spark.rdd.SplitsPruningRDD import spark.rdd.UnionRDD import spark.rdd.ZippedRDD import spark.storage.StorageLevel @@ -543,6 +544,15 @@ abstract class RDD[T: ClassManifest]( map(x => (f(x), x)) } + /** + * Prune splits (partitions) so Spark can avoid launching tasks on + * all splits. An example use case: If we know the RDD is partitioned by range, + * and the execution DAG has a filter on the key, we can avoid launching tasks + * on splits that don't have the range covering the key. + */ + def pruneSplits(splitsFilterFunc: Int => Boolean): RDD[T] = + new SplitsPruningRDD(this, splitsFilterFunc) + /** A private method for tests, to look at the contents of each partition */ private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) diff --git a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala new file mode 100644 index 0000000000..74e10265fc --- /dev/null +++ b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala @@ -0,0 +1,24 @@ +package spark.rdd + +import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext} + +/** + * A RDD used to prune RDD splits so we can avoid launching tasks on + * all splits. An example use case: If we know the RDD is partitioned by range, + * and the execution DAG has a filter on the key, we can avoid launching tasks + * on splits that don't have the range covering the key. + */ +class SplitsPruningRDD[T: ClassManifest]( + prev: RDD[T], + @transient splitsFilterFunc: Int => Boolean) + extends RDD[T](prev) { + + @transient + val _splits: Array[Split] = prev.splits.filter(s => splitsFilterFunc(s.index)) + + override def compute(split: Split, context: TaskContext) = prev.iterator(split, context) + + override protected def getSplits = _splits + + override val partitioner = prev.partitioner +} diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index db217f8482..03aa2845f4 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -1,11 +1,9 @@ package spark import scala.collection.mutable.HashMap -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - +import org.scalatest.{BeforeAndAfter, FunSuite} +import spark.SparkContext._ import spark.rdd.CoalescedRDD -import SparkContext._ class RDDSuite extends FunSuite with BeforeAndAfter { @@ -104,7 +102,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { } test("caching with failures") { - sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val onlySplit = new Split { override def index: Int = 0 } var shouldFail = true val rdd = new RDD[Int](sc, Nil) { @@ -136,8 +134,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter { List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) // Check that the narrow dependency is also specified correctly - assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === List(0, 1, 2, 3, 4)) - assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === + List(0, 1, 2, 3, 4)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === + List(5, 6, 7, 8, 9)) val coalesced2 = new CoalescedRDD(data, 3) assert(coalesced2.collect().toList === (1 to 10).toList) @@ -168,4 +168,12 @@ class RDDSuite extends FunSuite with BeforeAndAfter { nums.zip(sc.parallelize(1 to 4, 1)).collect() } } + + test("split pruning") { + sc = new SparkContext("local", "test") + val data = sc.parallelize(1 to 10, 10) + // Note that split number starts from 0, so > 8 means only 10th partition left. + val prunedData = data.pruneSplits(splitNum => splitNum > 8).collect + assert(prunedData.size == 1 && prunedData(0) == 10) + } } -- cgit v1.2.3 From c24b3819dd474e13d6098150c174b2e7e4bc6498 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 15:34:59 -0800 Subject: Added an extra assert for split size check. --- core/src/test/scala/spark/RDDSuite.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 03aa2845f4..ef74c99246 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -173,7 +173,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter { sc = new SparkContext("local", "test") val data = sc.parallelize(1 to 10, 10) // Note that split number starts from 0, so > 8 means only 10th partition left. - val prunedData = data.pruneSplits(splitNum => splitNum > 8).collect - assert(prunedData.size == 1 && prunedData(0) == 10) + val prunedRdd = data.pruneSplits(splitNum => splitNum > 8) + assert(prunedRdd.splits.size == 1) + val prunedData = prunedRdd.collect + assert(prunedData.size == 1) + assert(prunedData(0) == 10) } } -- cgit v1.2.3 From 45cd50d5fe40869cdc237157e073cfb5ac47b27c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 16:06:58 -0800 Subject: Updated assert == to ===. --- core/src/test/scala/spark/RDDSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index ef74c99246..5a3a12dfff 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -174,9 +174,9 @@ class RDDSuite extends FunSuite with BeforeAndAfter { val data = sc.parallelize(1 to 10, 10) // Note that split number starts from 0, so > 8 means only 10th partition left. val prunedRdd = data.pruneSplits(splitNum => splitNum > 8) - assert(prunedRdd.splits.size == 1) + assert(prunedRdd.splits.size === 1) val prunedData = prunedRdd.collect - assert(prunedData.size == 1) - assert(prunedData(0) == 10) + assert(prunedData.size === 1) + assert(prunedData(0) === 10) } } -- cgit v1.2.3 From 636e912f3289e422be9550752f5279d519062b75 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 21:21:55 -0800 Subject: Created a PruneDependency to properly assign dependency for SplitsPruningRDD. --- core/src/main/scala/spark/Dependency.scala | 24 +++++++++++++++++++--- .../main/scala/spark/rdd/SplitsPruningRDD.scala | 8 ++++---- 2 files changed, 25 insertions(+), 7 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index b85d2732db..7d5858e88e 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -5,6 +5,7 @@ package spark */ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable + /** * Base class for dependencies where each partition of the parent RDD is used by at most one * partition of the child RDD. Narrow dependencies allow for pipelined execution. @@ -12,12 +13,13 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { /** * Get the parent partitions for a child partition. - * @param outputPartition a partition of the child RDD + * @param partitionId a partition of the child RDD * @return the partitions of the parent RDD that the child partition depends upon */ - def getParents(outputPartition: Int): Seq[Int] + def getParents(partitionId: Int): Seq[Int] } + /** * Represents a dependency on the output of a shuffle stage. * @param shuffleId the shuffle id @@ -32,6 +34,7 @@ class ShuffleDependency[K, V]( val shuffleId: Int = rdd.context.newShuffleId() } + /** * Represents a one-to-one dependency between partitions of the parent and child RDDs. */ @@ -39,6 +42,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { override def getParents(partitionId: Int) = List(partitionId) } + /** * Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs. * @param rdd the parent RDD @@ -48,7 +52,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { */ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) extends NarrowDependency[T](rdd) { - + override def getParents(partitionId: Int) = { if (partitionId >= outStart && partitionId < outStart + length) { List(partitionId - outStart + inStart) @@ -57,3 +61,17 @@ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) } } } + + +/** + * Represents a dependency between the SplitsPruningRDD and its parent. In this + * case, the child RDD contains a subset of splits of the parents'. + */ +class PruneDependency[T](rdd: RDD[T], @transient splitsFilterFunc: Int => Boolean) + extends NarrowDependency[T](rdd) { + + @transient + val splits: Array[Split] = rdd.splits.filter(s => splitsFilterFunc(s.index)) + + override def getParents(partitionId: Int) = List(splits(partitionId).index) +} diff --git a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala index 74e10265fc..7b44d85bb5 100644 --- a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala @@ -1,6 +1,6 @@ package spark.rdd -import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext} +import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} /** * A RDD used to prune RDD splits so we can avoid launching tasks on @@ -11,12 +11,12 @@ import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext} class SplitsPruningRDD[T: ClassManifest]( prev: RDD[T], @transient splitsFilterFunc: Int => Boolean) - extends RDD[T](prev) { + extends RDD[T](prev.context, List(new PruneDependency(prev, splitsFilterFunc))) { @transient - val _splits: Array[Split] = prev.splits.filter(s => splitsFilterFunc(s.index)) + val _splits: Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].splits - override def compute(split: Split, context: TaskContext) = prev.iterator(split, context) + override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) override protected def getSplits = _splits -- cgit v1.2.3 From 81004b967e838fca0790727a3fea5a265ddbc69a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 21:54:27 -0800 Subject: Marked prev RDD as transient in SplitsPruningRDD. --- core/src/main/scala/spark/rdd/SplitsPruningRDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala index 7b44d85bb5..9b1a210ba3 100644 --- a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala @@ -9,7 +9,7 @@ import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} * on splits that don't have the range covering the key. */ class SplitsPruningRDD[T: ClassManifest]( - prev: RDD[T], + @transient prev: RDD[T], @transient splitsFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, splitsFilterFunc))) { @@ -20,5 +20,5 @@ class SplitsPruningRDD[T: ClassManifest]( override protected def getSplits = _splits - override val partitioner = prev.partitioner + override val partitioner = firstParent[T].partitioner } -- cgit v1.2.3 From eedc542a0276a5248c81446ee84f56d691e5f488 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 22:14:23 -0800 Subject: Removed pruneSplits method in RDD and renamed SplitsPruningRDD to PartitionPruningRDD. --- core/src/main/scala/spark/RDD.scala | 10 --------- .../main/scala/spark/rdd/PartitionPruningRDD.scala | 24 ++++++++++++++++++++++ .../main/scala/spark/rdd/SplitsPruningRDD.scala | 24 ---------------------- core/src/test/scala/spark/RDDSuite.scala | 6 +++--- 4 files changed, 27 insertions(+), 37 deletions(-) create mode 100644 core/src/main/scala/spark/rdd/PartitionPruningRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/SplitsPruningRDD.scala (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 3d93ff33bb..e0d2eabb1d 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -40,7 +40,6 @@ import spark.rdd.MapPartitionsRDD import spark.rdd.MapPartitionsWithSplitRDD import spark.rdd.PipedRDD import spark.rdd.SampledRDD -import spark.rdd.SplitsPruningRDD import spark.rdd.UnionRDD import spark.rdd.ZippedRDD import spark.storage.StorageLevel @@ -544,15 +543,6 @@ abstract class RDD[T: ClassManifest]( map(x => (f(x), x)) } - /** - * Prune splits (partitions) so Spark can avoid launching tasks on - * all splits. An example use case: If we know the RDD is partitioned by range, - * and the execution DAG has a filter on the key, we can avoid launching tasks - * on splits that don't have the range covering the key. - */ - def pruneSplits(splitsFilterFunc: Int => Boolean): RDD[T] = - new SplitsPruningRDD(this, splitsFilterFunc) - /** A private method for tests, to look at the contents of each partition */ private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala new file mode 100644 index 0000000000..3048949ef2 --- /dev/null +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -0,0 +1,24 @@ +package spark.rdd + +import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} + +/** + * A RDD used to prune RDD partitions/splits so we can avoid launching tasks on + * all partitions. An example use case: If we know the RDD is partitioned by range, + * and the execution DAG has a filter on the key, we can avoid launching tasks + * on partitions that don't have the range covering the key. + */ +class PartitionPruningRDD[T: ClassManifest]( + @transient prev: RDD[T], + @transient partitionFilterFunc: Int => Boolean) + extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { + + @transient + val partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].splits + + override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) + + override protected def getSplits = partitions_ + + override val partitioner = firstParent[T].partitioner +} diff --git a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala deleted file mode 100644 index 9b1a210ba3..0000000000 --- a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala +++ /dev/null @@ -1,24 +0,0 @@ -package spark.rdd - -import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} - -/** - * A RDD used to prune RDD splits so we can avoid launching tasks on - * all splits. An example use case: If we know the RDD is partitioned by range, - * and the execution DAG has a filter on the key, we can avoid launching tasks - * on splits that don't have the range covering the key. - */ -class SplitsPruningRDD[T: ClassManifest]( - @transient prev: RDD[T], - @transient splitsFilterFunc: Int => Boolean) - extends RDD[T](prev.context, List(new PruneDependency(prev, splitsFilterFunc))) { - - @transient - val _splits: Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].splits - - override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) - - override protected def getSplits = _splits - - override val partitioner = firstParent[T].partitioner -} diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 5a3a12dfff..73846131a9 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -3,7 +3,7 @@ package spark import scala.collection.mutable.HashMap import org.scalatest.{BeforeAndAfter, FunSuite} import spark.SparkContext._ -import spark.rdd.CoalescedRDD +import spark.rdd.{CoalescedRDD, PartitionPruningRDD} class RDDSuite extends FunSuite with BeforeAndAfter { @@ -169,11 +169,11 @@ class RDDSuite extends FunSuite with BeforeAndAfter { } } - test("split pruning") { + test("partition pruning") { sc = new SparkContext("local", "test") val data = sc.parallelize(1 to 10, 10) // Note that split number starts from 0, so > 8 means only 10th partition left. - val prunedRdd = data.pruneSplits(splitNum => splitNum > 8) + val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) assert(prunedRdd.splits.size === 1) val prunedData = prunedRdd.collect assert(prunedData.size === 1) -- cgit v1.2.3 From c109f29c97c9606dee45e6300d01a272dbb560aa Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 22:22:03 -0800 Subject: Updated PruneDependency to change "split" to "partition". --- core/src/main/scala/spark/Dependency.scala | 10 +++++----- core/src/main/scala/spark/rdd/PartitionPruningRDD.scala | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index 7d5858e88e..647aee6eb5 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -64,14 +64,14 @@ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) /** - * Represents a dependency between the SplitsPruningRDD and its parent. In this - * case, the child RDD contains a subset of splits of the parents'. + * Represents a dependency between the PartitionPruningRDD and its parent. In this + * case, the child RDD contains a subset of partitions of the parents'. */ -class PruneDependency[T](rdd: RDD[T], @transient splitsFilterFunc: Int => Boolean) +class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) extends NarrowDependency[T](rdd) { @transient - val splits: Array[Split] = rdd.splits.filter(s => splitsFilterFunc(s.index)) + val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index)) - override def getParents(partitionId: Int) = List(splits(partitionId).index) + override def getParents(partitionId: Int) = List(partitions(partitionId).index) } diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 3048949ef2..787b59ae8c 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -14,7 +14,7 @@ class PartitionPruningRDD[T: ClassManifest]( extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { @transient - val partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].splits + val partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].partitions override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) -- cgit v1.2.3 From 67a43bc7e622e4dd9d53ccf80b441740d6ff4df5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 23:06:52 -0800 Subject: Added a clearDependencies method in PartitionPruningRDD. --- core/src/main/scala/spark/rdd/PartitionPruningRDD.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 787b59ae8c..97dd37950e 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -14,11 +14,16 @@ class PartitionPruningRDD[T: ClassManifest]( extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { @transient - val partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].partitions + var partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].partitions override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) override protected def getSplits = partitions_ override val partitioner = firstParent[T].partitioner + + override def clearDependencies() { + super.clearDependencies() + partitions_ = null + } } -- cgit v1.2.3 From 230bda204778e6f3c0f5a20ad341f643146d97cb Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 15 Jan 2013 14:01:19 -0600 Subject: Add LocalSparkContext to manage common sc variable. --- core/src/test/scala/spark/AccumulatorSuite.scala | 32 ++-------- core/src/test/scala/spark/BroadcastSuite.scala | 14 +---- core/src/test/scala/spark/CheckpointSuite.scala | 19 ++---- .../src/test/scala/spark/ClosureCleanerSuite.scala | 73 ++++++++++------------ core/src/test/scala/spark/DistributedSuite.scala | 23 ++----- core/src/test/scala/spark/FailureSuite.scala | 14 +---- core/src/test/scala/spark/FileServerSuite.scala | 16 ++--- core/src/test/scala/spark/FileSuite.scala | 16 +---- core/src/test/scala/spark/LocalSparkContext.scala | 41 ++++++++++++ .../test/scala/spark/MapOutputTrackerSuite.scala | 7 +-- core/src/test/scala/spark/PartitioningSuite.scala | 15 +---- core/src/test/scala/spark/PipedRDDSuite.scala | 16 +---- core/src/test/scala/spark/RDDSuite.scala | 14 +---- core/src/test/scala/spark/ShuffleSuite.scala | 14 +---- core/src/test/scala/spark/SortingSuite.scala | 13 +--- core/src/test/scala/spark/ThreadingSuite.scala | 14 +---- .../scala/spark/scheduler/TaskContextSuite.scala | 14 +---- 17 files changed, 109 insertions(+), 246 deletions(-) create mode 100644 core/src/test/scala/spark/LocalSparkContext.scala (limited to 'core') diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala index d8be99dde7..78d64a44ae 100644 --- a/core/src/test/scala/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/spark/AccumulatorSuite.scala @@ -1,6 +1,5 @@ package spark -import org.scalatest.BeforeAndAfter import org.scalatest.FunSuite import org.scalatest.matchers.ShouldMatchers import collection.mutable @@ -9,18 +8,7 @@ import scala.math.exp import scala.math.signum import spark.SparkContext._ -class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { - - var sc: SparkContext = null - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext { test ("basic accumulation"){ sc = new SparkContext("local", "test") @@ -53,10 +41,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter for (i <- 1 to maxI) { v should contain(i) } - sc.stop() - sc = null - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + resetSparkContext() } } @@ -86,10 +71,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter x => acc.value += x } } should produce [SparkException] - sc.stop() - sc = null - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + resetSparkContext() } } @@ -115,10 +97,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter bufferAcc.value should contain(i) mapAcc.value should contain (i -> i.toString) } - sc.stop() - sc = null - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + resetSparkContext() } } @@ -134,8 +113,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter x => acc.localValue ++= x } acc.value should be ( (0 to maxI).toSet) - sc.stop() - sc = null + resetSparkContext() } } diff --git a/core/src/test/scala/spark/BroadcastSuite.scala b/core/src/test/scala/spark/BroadcastSuite.scala index 2d3302f0aa..362a31fb0d 100644 --- a/core/src/test/scala/spark/BroadcastSuite.scala +++ b/core/src/test/scala/spark/BroadcastSuite.scala @@ -1,20 +1,8 @@ package spark import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter -class BroadcastSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class BroadcastSuite extends FunSuite with LocalSparkContext { test("basic broadcast") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 51573254ca..33c317720c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -1,34 +1,27 @@ package spark -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.FunSuite import java.io.File import spark.rdd._ import spark.SparkContext._ import storage.StorageLevel -class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { +class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { initLogging() - var sc: SparkContext = _ var checkpointDir: File = _ val partitioner = new HashPartitioner(2) - before { + override def beforeEach() { + super.beforeEach() checkpointDir = File.createTempFile("temp", "") checkpointDir.delete() - sc = new SparkContext("local", "test") sc.setCheckpointDir(checkpointDir.toString) } - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - + override def afterEach() { + super.afterEach() if (checkpointDir != null) { checkpointDir.delete() } diff --git a/core/src/test/scala/spark/ClosureCleanerSuite.scala b/core/src/test/scala/spark/ClosureCleanerSuite.scala index dfa2de80e6..b2d0dd4627 100644 --- a/core/src/test/scala/spark/ClosureCleanerSuite.scala +++ b/core/src/test/scala/spark/ClosureCleanerSuite.scala @@ -3,6 +3,7 @@ package spark import java.io.NotSerializableException import org.scalatest.FunSuite +import spark.LocalSparkContext._ import SparkContext._ class ClosureCleanerSuite extends FunSuite { @@ -43,13 +44,10 @@ object TestObject { def run(): Int = { var nonSer = new NonSerializable var x = 5 - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - val answer = nums.map(_ + x).reduce(_ + _) - sc.stop() - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - return answer + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + x).reduce(_ + _) + } } } @@ -60,11 +58,10 @@ class TestClass extends Serializable { def run(): Int = { var nonSer = new NonSerializable - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - val answer = nums.map(_ + getX).reduce(_ + _) - sc.stop() - return answer + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + getX).reduce(_ + _) + } } } @@ -73,11 +70,10 @@ class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { def run(): Int = { var nonSer = new NonSerializable - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - val answer = nums.map(_ + getX).reduce(_ + _) - sc.stop() - return answer + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + getX).reduce(_ + _) + } } } @@ -89,11 +85,10 @@ class TestClassWithoutFieldAccess { def run(): Int = { var nonSer2 = new NonSerializable var x = 5 - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - val answer = nums.map(_ + x).reduce(_ + _) - sc.stop() - return answer + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + x).reduce(_ + _) + } } } @@ -102,16 +97,16 @@ object TestObjectWithNesting { def run(): Int = { var nonSer = new NonSerializable var answer = 0 - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - var y = 1 - for (i <- 1 to 4) { - var nonSer2 = new NonSerializable - var x = i - answer += nums.map(_ + x + y).reduce(_ + _) + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + var y = 1 + for (i <- 1 to 4) { + var nonSer2 = new NonSerializable + var x = i + answer += nums.map(_ + x + y).reduce(_ + _) + } + answer } - sc.stop() - return answer } } @@ -121,14 +116,14 @@ class TestClassWithNesting(val y: Int) extends Serializable { def run(): Int = { var nonSer = new NonSerializable var answer = 0 - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - for (i <- 1 to 4) { - var nonSer2 = new NonSerializable - var x = i - answer += nums.map(_ + x + getY).reduce(_ + _) + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + for (i <- 1 to 4) { + var nonSer2 = new NonSerializable + var x = i + answer += nums.map(_ + x + getY).reduce(_ + _) + } + answer } - sc.stop() - return answer } } diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index cacc2796b6..83a2a549a9 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -15,41 +15,28 @@ import scala.collection.mutable.ArrayBuffer import SparkContext._ import storage.StorageLevel -class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { +class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext { val clusterUrl = "local-cluster[2,1,512]" - @transient var sc: SparkContext = _ - after { - if (sc != null) { - sc.stop() - sc = null - } System.clearProperty("spark.reducer.maxMbInFlight") System.clearProperty("spark.storage.memoryFraction") - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") } test("local-cluster format") { sc = new SparkContext("local-cluster[2,1,512]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) - sc.stop() - System.clearProperty("spark.master.port") + resetSparkContext() sc = new SparkContext("local-cluster[2 , 1 , 512]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) - sc.stop() - System.clearProperty("spark.master.port") + resetSparkContext() sc = new SparkContext("local-cluster[2, 1, 512]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) - sc.stop() - System.clearProperty("spark.master.port") + resetSparkContext() sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) - sc.stop() - System.clearProperty("spark.master.port") - sc = null + resetSparkContext() } test("simple groupByKey") { diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala index a3454f25f6..8c1445a465 100644 --- a/core/src/test/scala/spark/FailureSuite.scala +++ b/core/src/test/scala/spark/FailureSuite.scala @@ -1,7 +1,6 @@ package spark import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import org.scalatest.prop.Checkers import scala.collection.mutable.ArrayBuffer @@ -23,18 +22,7 @@ object FailureSuiteState { } } -class FailureSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class FailureSuite extends FunSuite with LocalSparkContext { // Run a 3-task map job in which task 1 deterministically fails once, and check // whether the job completes successfully and we ran 4 tasks in total. diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala index b4283d9604..8215cbde02 100644 --- a/core/src/test/scala/spark/FileServerSuite.scala +++ b/core/src/test/scala/spark/FileServerSuite.scala @@ -2,17 +2,16 @@ package spark import com.google.common.io.Files import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import java.io.{File, PrintWriter, FileReader, BufferedReader} import SparkContext._ -class FileServerSuite extends FunSuite with BeforeAndAfter { +class FileServerSuite extends FunSuite with LocalSparkContext { - @transient var sc: SparkContext = _ @transient var tmpFile : File = _ @transient var testJarFile : File = _ - before { + override def beforeEach() { + super.beforeEach() // Create a sample text file val tmpdir = new File(Files.createTempDir(), "test") tmpdir.mkdir() @@ -22,17 +21,12 @@ class FileServerSuite extends FunSuite with BeforeAndAfter { pw.close() } - after { - if (sc != null) { - sc.stop() - sc = null - } + override def afterEach() { + super.afterEach() // Clean up downloaded file if (tmpFile.exists) { tmpFile.delete() } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") } test("Distributing files locally") { diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala index 554bea53a9..91b48c7456 100644 --- a/core/src/test/scala/spark/FileSuite.scala +++ b/core/src/test/scala/spark/FileSuite.scala @@ -6,24 +6,12 @@ import scala.io.Source import com.google.common.io.Files import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import org.apache.hadoop.io._ import SparkContext._ -class FileSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } - +class FileSuite extends FunSuite with LocalSparkContext { + test("text files") { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala new file mode 100644 index 0000000000..b5e31ddae3 --- /dev/null +++ b/core/src/test/scala/spark/LocalSparkContext.scala @@ -0,0 +1,41 @@ +package spark + +import org.scalatest.Suite +import org.scalatest.BeforeAndAfterEach + +/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */ +trait LocalSparkContext extends BeforeAndAfterEach { self: Suite => + + @transient var sc: SparkContext = _ + + override def afterEach() { + resetSparkContext() + super.afterEach() + } + + def resetSparkContext() = { + if (sc != null) { + LocalSparkContext.stop(sc) + sc = null + } + } + +} + +object LocalSparkContext { + def stop(sc: SparkContext) { + sc.stop() + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") + } + + /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ + def withSpark[T](sc: SparkContext)(f: SparkContext => T) = { + try { + f(sc) + } finally { + stop(sc) + } + } + +} \ No newline at end of file diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index d3dd3a8fa4..774bbd65b1 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -1,17 +1,13 @@ package spark import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import akka.actor._ import spark.scheduler.MapStatus import spark.storage.BlockManagerId import spark.util.AkkaUtils -class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { - after { - System.clearProperty("spark.master.port") - } +class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("compressSize") { assert(MapOutputTracker.compressSize(0L) === 0) @@ -81,7 +77,6 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { } test("remote fetch") { - System.clearProperty("spark.master.host") val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) System.setProperty("spark.master.port", boundPort.toString) diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index eb3c8f238f..af1107cd19 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -1,25 +1,12 @@ package spark import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import scala.collection.mutable.ArrayBuffer import SparkContext._ -class PartitioningSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if(sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } - +class PartitioningSuite extends FunSuite with LocalSparkContext { test("HashPartitioner equality") { val p2 = new HashPartitioner(2) diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala index 9b84b29227..a6344edf8f 100644 --- a/core/src/test/scala/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/spark/PipedRDDSuite.scala @@ -1,21 +1,9 @@ package spark import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import SparkContext._ -class PipedRDDSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class PipedRDDSuite extends FunSuite with LocalSparkContext { test("basic pipe") { sc = new SparkContext("local", "test") @@ -51,5 +39,3 @@ class PipedRDDSuite extends FunSuite with BeforeAndAfter { } } - - diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index db217f8482..592427e97a 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -2,23 +2,11 @@ package spark import scala.collection.mutable.HashMap import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import spark.rdd.CoalescedRDD import SparkContext._ -class RDDSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class RDDSuite extends FunSuite with LocalSparkContext { test("basic operations") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index bebb8ebe86..3493b9511f 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -3,7 +3,6 @@ package spark import scala.collection.mutable.ArrayBuffer import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import org.scalatest.matchers.ShouldMatchers import org.scalatest.prop.Checkers import org.scalacheck.Arbitrary._ @@ -15,18 +14,7 @@ import com.google.common.io.Files import spark.rdd.ShuffledRDD import spark.SparkContext._ -class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { test("groupByKey") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala index 1ad11ff4c3..edb8c839fc 100644 --- a/core/src/test/scala/spark/SortingSuite.scala +++ b/core/src/test/scala/spark/SortingSuite.scala @@ -5,18 +5,7 @@ import org.scalatest.BeforeAndAfter import org.scalatest.matchers.ShouldMatchers import SparkContext._ -class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with Logging { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers with Logging { test("sortByKey") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/spark/ThreadingSuite.scala b/core/src/test/scala/spark/ThreadingSuite.scala index e9b1837d89..ff315b6693 100644 --- a/core/src/test/scala/spark/ThreadingSuite.scala +++ b/core/src/test/scala/spark/ThreadingSuite.scala @@ -22,19 +22,7 @@ object ThreadingSuiteState { } } -class ThreadingSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if(sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } - +class ThreadingSuite extends FunSuite with LocalSparkContext { test("accessing SparkContext form a different thread") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala index ba6f8b588f..a5db7103f5 100644 --- a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala @@ -6,19 +6,9 @@ import spark.TaskContext import spark.RDD import spark.SparkContext import spark.Split +import spark.LocalSparkContext -class TaskContextSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { test("Calls executeOnCompleteCallbacks after failure") { var completed = false -- cgit v1.2.3 From b6fc6e67521e8a9a5291693cce3dc766da244395 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 24 Jan 2013 14:28:05 -0800 Subject: SPARK-541: Adding a warning for invalid Master URL Right now Spark silently parses master URL's which do not match any known regex as a Mesos URL. The Mesos error message when an invalid URL gets passed is really confusing, so this warns the user when the implicit conversion is happening. --- core/src/main/scala/spark/SparkContext.scala | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 66bdbe7cda..bc9fdee8b6 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -112,6 +112,8 @@ class SparkContext( val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters val SPARK_REGEX = """(spark://.*)""".r + //Regular expression for connection to Mesos cluster + val MESOS_REGEX = """(mesos://.*)""".r master match { case "local" => @@ -152,6 +154,9 @@ class SparkContext( scheduler case _ => + if (MESOS_REGEX.findFirstIn(master).isEmpty) { + logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) + } MesosNativeLibrary.load() val scheduler = new ClusterScheduler(this) val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean -- cgit v1.2.3 From 7dfb82a992d47491174d7929e31351d26cadfcda Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 15:25:41 -0600 Subject: Replace old 'master' term with 'driver'. --- bagel/src/test/scala/bagel/BagelSuite.scala | 2 +- core/src/main/scala/spark/MapOutputTracker.scala | 10 +-- core/src/main/scala/spark/SparkContext.scala | 20 +++--- core/src/main/scala/spark/SparkEnv.scala | 22 +++---- .../spark/broadcast/BitTorrentBroadcast.scala | 24 +++---- .../src/main/scala/spark/broadcast/Broadcast.scala | 6 +- .../scala/spark/broadcast/BroadcastFactory.scala | 4 +- .../main/scala/spark/broadcast/HttpBroadcast.scala | 6 +- .../main/scala/spark/broadcast/MultiTracker.scala | 35 +++++----- .../main/scala/spark/broadcast/TreeBroadcast.scala | 52 +++++++-------- .../scala/spark/deploy/LocalSparkCluster.scala | 34 +++++----- .../scala/spark/deploy/client/ClientListener.scala | 4 +- .../main/scala/spark/deploy/master/JobInfo.scala | 2 +- .../main/scala/spark/deploy/master/Master.scala | 18 +++--- .../spark/executor/StandaloneExecutorBackend.scala | 26 ++++---- .../cluster/SparkDeploySchedulerBackend.scala | 33 +++++----- .../cluster/StandaloneClusterMessage.scala | 8 +-- .../cluster/StandaloneSchedulerBackend.scala | 74 +++++++++++----------- .../mesos/CoarseMesosSchedulerBackend.scala | 6 +- .../scala/spark/storage/BlockManagerMaster.scala | 69 ++++++++++---------- .../main/scala/spark/storage/ThreadingTest.scala | 6 +- core/src/test/scala/spark/JavaAPISuite.java | 2 +- core/src/test/scala/spark/LocalSparkContext.scala | 2 +- .../test/scala/spark/MapOutputTrackerSuite.scala | 2 +- docs/configuration.md | 12 ++-- python/pyspark/tests.py | 2 +- repl/src/test/scala/spark/repl/ReplSuite.scala | 2 +- .../streaming/dstream/NetworkInputDStream.scala | 4 +- .../test/java/spark/streaming/JavaAPISuite.java | 2 +- .../spark/streaming/BasicOperationsSuite.scala | 2 +- .../scala/spark/streaming/CheckpointSuite.scala | 2 +- .../test/scala/spark/streaming/FailureSuite.scala | 2 +- .../scala/spark/streaming/InputStreamsSuite.scala | 2 +- .../spark/streaming/WindowOperationsSuite.scala | 2 +- 34 files changed, 248 insertions(+), 251 deletions(-) (limited to 'core') diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala index ca59f46843..3c2f9c4616 100644 --- a/bagel/src/test/scala/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/bagel/BagelSuite.scala @@ -23,7 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter { sc = null } // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } test("halting by voting") { diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index ac02f3363a..d4f5164f7d 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -38,10 +38,7 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac } } -private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging { - val ip: String = System.getProperty("spark.master.host", "localhost") - val port: Int = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "MapOutputTracker" +private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolean) extends Logging { val timeout = 10.seconds @@ -56,11 +53,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea var cacheGeneration = generation val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] - var trackerActor: ActorRef = if (isMaster) { + val actorName: String = "MapOutputTracker" + var trackerActor: ActorRef = if (isDriver) { val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName) logInfo("Registered MapOutputTrackerActor actor") actor } else { + val ip = System.getProperty("spark.driver.host", "localhost") + val port = System.getProperty("spark.driver.port", "7077").toInt val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) actorSystem.actorFor(url) } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bc9fdee8b6..d4991cb1e0 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -66,20 +66,20 @@ class SparkContext( // Ensure logging is initialized before we spawn any threads initLogging() - // Set Spark master host and port system properties - if (System.getProperty("spark.master.host") == null) { - System.setProperty("spark.master.host", Utils.localIpAddress) + // Set Spark driver host and port system properties + if (System.getProperty("spark.driver.host") == null) { + System.setProperty("spark.driver.host", Utils.localIpAddress) } - if (System.getProperty("spark.master.port") == null) { - System.setProperty("spark.master.port", "0") + if (System.getProperty("spark.driver.port") == null) { + System.setProperty("spark.driver.port", "0") } private val isLocal = (master == "local" || master.startsWith("local[")) // Create the Spark execution environment (cache, map output tracker, etc) private[spark] val env = SparkEnv.createFromSystemProperties( - System.getProperty("spark.master.host"), - System.getProperty("spark.master.port").toInt, + System.getProperty("spark.driver.host"), + System.getProperty("spark.driver.port").toInt, true, isLocal) SparkEnv.set(env) @@ -396,14 +396,14 @@ class SparkContext( /** * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `+=` method. Only the driver can access the accumulator's `value`. */ def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = new Accumulator(initialValue, param) /** * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`. - * Only the master can access the accumuable's `value`. + * Only the driver can access the accumuable's `value`. * @tparam T accumulator type * @tparam R type that can be added to the accumulator */ @@ -530,7 +530,7 @@ class SparkContext( /** * Run a function on a given set of partitions in an RDD and return the results. This is the main * entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies - * whether the scheduler can run the computation on the master rather than shipping it out to the + * whether the scheduler can run the computation on the driver rather than shipping it out to the * cluster, for short actions like first(). */ def runJob[T, U: ClassManifest]( diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 2a7a8af83d..4034af610c 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -60,15 +60,15 @@ object SparkEnv extends Logging { def createFromSystemProperties( hostname: String, port: Int, - isMaster: Boolean, + isDriver: Boolean, isLocal: Boolean ) : SparkEnv = { val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port) - // Bit of a hack: If this is the master and our port was 0 (meaning bind to any free port), - // figure out which port number Akka actually bound to and set spark.master.port to it. - if (isMaster && port == 0) { - System.setProperty("spark.master.port", boundPort.toString) + // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port), + // figure out which port number Akka actually bound to and set spark.driver.port to it. + if (isDriver && port == 0) { + System.setProperty("spark.driver.port", boundPort.toString) } val classLoader = Thread.currentThread.getContextClassLoader @@ -82,22 +82,22 @@ object SparkEnv extends Logging { val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer") - val masterIp: String = System.getProperty("spark.master.host", "localhost") - val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt + val driverIp: String = System.getProperty("spark.driver.host", "localhost") + val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt val blockManagerMaster = new BlockManagerMaster( - actorSystem, isMaster, isLocal, masterIp, masterPort) + actorSystem, isDriver, isLocal, driverIp, driverPort) val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer) val connectionManager = blockManager.connectionManager - val broadcastManager = new BroadcastManager(isMaster) + val broadcastManager = new BroadcastManager(isDriver) val closureSerializer = instantiateClass[Serializer]( "spark.closure.serializer", "spark.JavaSerializer") val cacheManager = new CacheManager(blockManager) - val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster) + val mapOutputTracker = new MapOutputTracker(actorSystem, isDriver) val shuffleFetcher = instantiateClass[ShuffleFetcher]( "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") @@ -109,7 +109,7 @@ object SparkEnv extends Logging { // Set the sparkFiles directory, used when downloading dependencies. In local mode, // this is a temporary directory; in distributed mode, this is the executor's current working // directory. - val sparkFilesDir: String = if (isMaster) { + val sparkFilesDir: String = if (isDriver) { Utils.createTempDir().getAbsolutePath } else { "." diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala index 386f505f2a..adcb2d2415 100644 --- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala @@ -31,7 +31,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: @transient var totalBlocks = -1 @transient var hasBlocks = new AtomicInteger(0) - // Used ONLY by Master to track how many unique blocks have been sent out + // Used ONLY by driver to track how many unique blocks have been sent out @transient var sentBlocks = new AtomicInteger(0) @transient var listenPortLock = new Object @@ -42,7 +42,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: @transient var serveMR: ServeMultipleRequests = null - // Used only in Master + // Used only in driver @transient var guideMR: GuideMultipleRequests = null // Used only in Workers @@ -99,14 +99,14 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: } // Must always come AFTER listenPort is created - val masterSource = + val driverSource = SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) hasBlocksBitVector.synchronized { - masterSource.hasBlocksBitVector = hasBlocksBitVector + driverSource.hasBlocksBitVector = hasBlocksBitVector } // In the beginning, this is the only known source to Guide - listOfSources += masterSource + listOfSources += driverSource // Register with the Tracker MultiTracker.registerBroadcast(id, @@ -122,7 +122,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: case None => logInfo("Started reading broadcast variable " + id) - // Initializing everything because Master will only send null/0 values + // Initializing everything because driver will only send null/0 values // Only the 1st worker in a node can be here. Others will get from cache initializeWorkerVariables() @@ -151,7 +151,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: } } - // Initialize variables in the worker node. Master sends everything as 0/null + // Initialize variables in the worker node. Driver sends everything as 0/null private def initializeWorkerVariables() { arrayOfBlocks = null hasBlocksBitVector = null @@ -248,7 +248,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: // Receive source information from Guide var suitableSources = oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] - logDebug("Received suitableSources from Master " + suitableSources) + logDebug("Received suitableSources from Driver " + suitableSources) addToListOfSources(suitableSources) @@ -532,7 +532,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: oosSource.writeObject(blockToAskFor) oosSource.flush() - // CHANGED: Master might send some other block than the one + // CHANGED: Driver might send some other block than the one // requested to ensure fast spreading of all blocks. val recvStartTime = System.currentTimeMillis val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] @@ -982,9 +982,9 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: // Receive which block to send var blockToSend = ois.readObject.asInstanceOf[Int] - // If it is master AND at least one copy of each block has not been + // If it is driver AND at least one copy of each block has not been // sent out already, MODIFY blockToSend - if (MultiTracker.isMaster && sentBlocks.get < totalBlocks) { + if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) { blockToSend = sentBlocks.getAndIncrement } @@ -1031,7 +1031,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: private[spark] class BitTorrentBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) } + def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new BitTorrentBroadcast[T](value_, isLocal, id) diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala index 2ffe7f741d..415bde5d67 100644 --- a/core/src/main/scala/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/spark/broadcast/Broadcast.scala @@ -15,7 +15,7 @@ abstract class Broadcast[T](private[spark] val id: Long) extends Serializable { } private[spark] -class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable { +class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable { private var initialized = false private var broadcastFactory: BroadcastFactory = null @@ -33,7 +33,7 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] // Initialize appropriate BroadcastFactory and BroadcastObject - broadcastFactory.initialize(isMaster) + broadcastFactory.initialize(isDriver) initialized = true } @@ -49,5 +49,5 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl def newBroadcast[T](value_ : T, isLocal: Boolean) = broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) - def isMaster = isMaster_ + def isDriver = _isDriver } diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala index ab6d302827..5c6184c3c7 100644 --- a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala @@ -7,7 +7,7 @@ package spark.broadcast * entire Spark job. */ private[spark] trait BroadcastFactory { - def initialize(isMaster: Boolean): Unit - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] + def initialize(isDriver: Boolean): Unit + def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] def stop(): Unit } diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 8e490e6bad..7e30b8f7d2 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -48,7 +48,7 @@ extends Broadcast[T](id) with Logging with Serializable { } private[spark] class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) { HttpBroadcast.initialize(isMaster) } + def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new HttpBroadcast[T](value_, isLocal, id) @@ -69,12 +69,12 @@ private object HttpBroadcast extends Logging { private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup) - def initialize(isMaster: Boolean) { + def initialize(isDriver: Boolean) { synchronized { if (!initialized) { bufferSize = System.getProperty("spark.buffer.size", "65536").toInt compress = System.getProperty("spark.broadcast.compress", "true").toBoolean - if (isMaster) { + if (isDriver) { createServer() } serverUri = System.getProperty("spark.httpBroadcast.uri") diff --git a/core/src/main/scala/spark/broadcast/MultiTracker.scala b/core/src/main/scala/spark/broadcast/MultiTracker.scala index 5e76dedb94..3fd77af73f 100644 --- a/core/src/main/scala/spark/broadcast/MultiTracker.scala +++ b/core/src/main/scala/spark/broadcast/MultiTracker.scala @@ -23,25 +23,24 @@ extends Logging { var ranGen = new Random private var initialized = false - private var isMaster_ = false + private var _isDriver = false private var stopBroadcast = false private var trackMV: TrackMultipleValues = null - def initialize(isMaster__ : Boolean) { + def initialize(__isDriver: Boolean) { synchronized { if (!initialized) { + _isDriver = __isDriver - isMaster_ = isMaster__ - - if (isMaster) { + if (isDriver) { trackMV = new TrackMultipleValues trackMV.setDaemon(true) trackMV.start() - // Set masterHostAddress to the master's IP address for the slaves to read - System.setProperty("spark.MultiTracker.MasterHostAddress", Utils.localIpAddress) + // Set DriverHostAddress to the driver's IP address for the slaves to read + System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress) } initialized = true @@ -54,10 +53,10 @@ extends Logging { } // Load common parameters - private var MasterHostAddress_ = System.getProperty( - "spark.MultiTracker.MasterHostAddress", "") - private var MasterTrackerPort_ = System.getProperty( - "spark.broadcast.masterTrackerPort", "11111").toInt + private var DriverHostAddress_ = System.getProperty( + "spark.MultiTracker.DriverHostAddress", "") + private var DriverTrackerPort_ = System.getProperty( + "spark.broadcast.driverTrackerPort", "11111").toInt private var BlockSize_ = System.getProperty( "spark.broadcast.blockSize", "4096").toInt * 1024 private var MaxRetryCount_ = System.getProperty( @@ -91,11 +90,11 @@ extends Logging { private var EndGameFraction_ = System.getProperty( "spark.broadcast.endGameFraction", "0.95").toDouble - def isMaster = isMaster_ + def isDriver = _isDriver // Common config params - def MasterHostAddress = MasterHostAddress_ - def MasterTrackerPort = MasterTrackerPort_ + def DriverHostAddress = DriverHostAddress_ + def DriverTrackerPort = DriverTrackerPort_ def BlockSize = BlockSize_ def MaxRetryCount = MaxRetryCount_ @@ -123,7 +122,7 @@ extends Logging { var threadPool = Utils.newDaemonCachedThreadPool() var serverSocket: ServerSocket = null - serverSocket = new ServerSocket(MasterTrackerPort) + serverSocket = new ServerSocket(DriverTrackerPort) logInfo("TrackMultipleValues started at " + serverSocket) try { @@ -235,7 +234,7 @@ extends Logging { try { // Connect to the tracker to find out GuideInfo clientSocketToTracker = - new Socket(MultiTracker.MasterHostAddress, MultiTracker.MasterTrackerPort) + new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort) oosTracker = new ObjectOutputStream(clientSocketToTracker.getOutputStream) oosTracker.flush() @@ -276,7 +275,7 @@ extends Logging { } def registerBroadcast(id: Long, gInfo: SourceInfo) { - val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort) + val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) val oosST = new ObjectOutputStream(socket.getOutputStream) oosST.flush() val oisST = new ObjectInputStream(socket.getInputStream) @@ -303,7 +302,7 @@ extends Logging { } def unregisterBroadcast(id: Long) { - val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort) + val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) val oosST = new ObjectOutputStream(socket.getOutputStream) oosST.flush() val oisST = new ObjectInputStream(socket.getInputStream) diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala index f573512835..c55c476117 100644 --- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala @@ -98,7 +98,7 @@ extends Broadcast[T](id) with Logging with Serializable { case None => logInfo("Started reading broadcast variable " + id) - // Initializing everything because Master will only send null/0 values + // Initializing everything because Driver will only send null/0 values // Only the 1st worker in a node can be here. Others will get from cache initializeWorkerVariables() @@ -157,55 +157,55 @@ extends Broadcast[T](id) with Logging with Serializable { listenPortLock.synchronized { listenPortLock.wait() } } - var clientSocketToMaster: Socket = null - var oosMaster: ObjectOutputStream = null - var oisMaster: ObjectInputStream = null + var clientSocketToDriver: Socket = null + var oosDriver: ObjectOutputStream = null + var oisDriver: ObjectInputStream = null // Connect and receive broadcast from the specified source, retrying the // specified number of times in case of failures var retriesLeft = MultiTracker.MaxRetryCount do { - // Connect to Master and send this worker's Information - clientSocketToMaster = new Socket(MultiTracker.MasterHostAddress, gInfo.listenPort) - oosMaster = new ObjectOutputStream(clientSocketToMaster.getOutputStream) - oosMaster.flush() - oisMaster = new ObjectInputStream(clientSocketToMaster.getInputStream) + // Connect to Driver and send this worker's Information + clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort) + oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream) + oosDriver.flush() + oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream) - logDebug("Connected to Master's guiding object") + logDebug("Connected to Driver's guiding object") // Send local source information - oosMaster.writeObject(SourceInfo(hostAddress, listenPort)) - oosMaster.flush() + oosDriver.writeObject(SourceInfo(hostAddress, listenPort)) + oosDriver.flush() - // Receive source information from Master - var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] + // Receive source information from Driver + var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo] totalBlocks = sourceInfo.totalBlocks arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } totalBytes = sourceInfo.totalBytes - logDebug("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) + logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort) val start = System.nanoTime val receptionSucceeded = receiveSingleTransmission(sourceInfo) val time = (System.nanoTime - start) / 1e9 - // Updating some statistics in sourceInfo. Master will be using them later + // Updating some statistics in sourceInfo. Driver will be using them later if (!receptionSucceeded) { sourceInfo.receptionFailed = true } - // Send back statistics to the Master - oosMaster.writeObject(sourceInfo) + // Send back statistics to the Driver + oosDriver.writeObject(sourceInfo) - if (oisMaster != null) { - oisMaster.close() + if (oisDriver != null) { + oisDriver.close() } - if (oosMaster != null) { - oosMaster.close() + if (oosDriver != null) { + oosDriver.close() } - if (clientSocketToMaster != null) { - clientSocketToMaster.close() + if (clientSocketToDriver != null) { + clientSocketToDriver.close() } retriesLeft -= 1 @@ -552,7 +552,7 @@ extends Broadcast[T](id) with Logging with Serializable { } private def sendObject() { - // Wait till receiving the SourceInfo from Master + // Wait till receiving the SourceInfo from Driver while (totalBlocks == -1) { totalBlocksLock.synchronized { totalBlocksLock.wait() } } @@ -576,7 +576,7 @@ extends Broadcast[T](id) with Logging with Serializable { private[spark] class TreeBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) } + def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new TreeBroadcast[T](value_, isLocal, id) diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala index 4211d80596..ae083efc8d 100644 --- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala @@ -10,7 +10,7 @@ import spark.{Logging, Utils} import scala.collection.mutable.ArrayBuffer private[spark] -class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging { +class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging { val localIpAddress = Utils.localIpAddress @@ -19,33 +19,31 @@ class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) var masterPort : Int = _ var masterUrl : String = _ - val slaveActorSystems = ArrayBuffer[ActorSystem]() - val slaveActors = ArrayBuffer[ActorRef]() + val workerActorSystems = ArrayBuffer[ActorSystem]() + val workerActors = ArrayBuffer[ActorRef]() def start() : String = { - logInfo("Starting a local Spark cluster with " + numSlaves + " slaves.") + logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") /* Start the Master */ val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0) masterActorSystem = actorSystem masterUrl = "spark://" + localIpAddress + ":" + masterPort - val actor = masterActorSystem.actorOf( + masterActor = masterActorSystem.actorOf( Props(new Master(localIpAddress, masterPort, 0)), name = "Master") - masterActor = actor - /* Start the Slaves */ - for (slaveNum <- 1 to numSlaves) { - /* We can pretend to test distributed stuff by giving the slaves distinct hostnames. + /* Start the Workers */ + for (workerNum <- 1 to numWorkers) { + /* We can pretend to test distributed stuff by giving the workers distinct hostnames. All of 127/8 should be a loopback, we use 127.100.*.* in hopes that it is sufficiently distinctive. */ - val slaveIpAddress = "127.100.0." + (slaveNum % 256) + val workerIpAddress = "127.100.0." + (workerNum % 256) val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("sparkWorker" + slaveNum, slaveIpAddress, 0) - slaveActorSystems += actorSystem - val actor = actorSystem.actorOf( - Props(new Worker(slaveIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)), + AkkaUtils.createActorSystem("sparkWorker" + workerNum, workerIpAddress, 0) + workerActorSystems += actorSystem + workerActors += actorSystem.actorOf( + Props(new Worker(workerIpAddress, boundPort, 0, coresPerWorker, memoryPerWorker, masterUrl)), name = "Worker") - slaveActors += actor } return masterUrl @@ -53,9 +51,9 @@ class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) def stop() { logInfo("Shutting down local Spark cluster.") - // Stop the slaves before the master so they don't get upset that it disconnected - slaveActorSystems.foreach(_.shutdown()) - slaveActorSystems.foreach(_.awaitTermination()) + // Stop the workers before the master so they don't get upset that it disconnected + workerActorSystems.foreach(_.shutdown()) + workerActorSystems.foreach(_.awaitTermination()) masterActorSystem.shutdown() masterActorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala index da6abcc9c2..7035f4b394 100644 --- a/core/src/main/scala/spark/deploy/client/ClientListener.scala +++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala @@ -12,7 +12,7 @@ private[spark] trait ClientListener { def disconnected(): Unit - def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int): Unit + def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit - def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit + def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit } diff --git a/core/src/main/scala/spark/deploy/master/JobInfo.scala b/core/src/main/scala/spark/deploy/master/JobInfo.scala index 130b031a2a..a274b21c34 100644 --- a/core/src/main/scala/spark/deploy/master/JobInfo.scala +++ b/core/src/main/scala/spark/deploy/master/JobInfo.scala @@ -10,7 +10,7 @@ private[spark] class JobInfo( val id: String, val desc: JobDescription, val submitDate: Date, - val actor: ActorRef) + val driver: ActorRef) { var state = JobState.WAITING var executors = new mutable.HashMap[Int, ExecutorInfo] diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 2c2cd0231b..3347207c6d 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -88,7 +88,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor execOption match { case Some(exec) => { exec.state = state - exec.job.actor ! ExecutorUpdated(execId, state, message, exitStatus) + exec.job.driver ! ExecutorUpdated(execId, state, message, exitStatus) if (ExecutorState.isFinished(state)) { val jobInfo = idToJob(jobId) // Remove this executor from the worker and job @@ -199,7 +199,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) - exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) + exec.job.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, @@ -221,19 +221,19 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor actorToWorker -= worker.actor addressToWorker -= worker.actor.path.address for (exec <- worker.executors.values) { - exec.job.actor ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None) + exec.job.driver ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None) exec.job.executors -= exec.id } } - def addJob(desc: JobDescription, actor: ActorRef): JobInfo = { + def addJob(desc: JobDescription, driver: ActorRef): JobInfo = { val now = System.currentTimeMillis() val date = new Date(now) - val job = new JobInfo(now, newJobId(date), desc, date, actor) + val job = new JobInfo(now, newJobId(date), desc, date, driver) jobs += job idToJob(job.id) = job - actorToJob(sender) = job - addressToJob(sender.path.address) = job + actorToJob(driver) = job + addressToJob(driver.path.address) = job return job } @@ -242,8 +242,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor logInfo("Removing job " + job.id) jobs -= job idToJob -= job.id - actorToJob -= job.actor - addressToWorker -= job.actor.path.address + actorToJob -= job.driver + addressToWorker -= job.driver.path.address completedJobs += job // Remember it in our history waitingJobs -= job for (exec <- job.executors.values) { diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index a29bf974d2..f80f1b5274 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -16,33 +16,33 @@ import spark.scheduler.cluster.RegisterSlave private[spark] class StandaloneExecutorBackend( executor: Executor, - masterUrl: String, - slaveId: String, + driverUrl: String, + workerId: String, hostname: String, cores: Int) extends Actor with ExecutorBackend with Logging { - var master: ActorRef = null + var driver: ActorRef = null override def preStart() { try { - logInfo("Connecting to master: " + masterUrl) - master = context.actorFor(masterUrl) - master ! RegisterSlave(slaveId, hostname, cores) + logInfo("Connecting to driver: " + driverUrl) + driver = context.actorFor(driverUrl) + driver ! RegisterSlave(workerId, hostname, cores) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(master) // Doesn't work with remote actors, but useful for testing + context.watch(driver) // Doesn't work with remote actors, but useful for testing } catch { case e: Exception => - logError("Failed to connect to master", e) + logError("Failed to connect to driver", e) System.exit(1) } } override def receive = { case RegisteredSlave(sparkProperties) => - logInfo("Successfully registered with master") + logInfo("Successfully registered with driver") executor.initialize(hostname, sparkProperties) case RegisterSlaveFailed(message) => @@ -55,24 +55,24 @@ private[spark] class StandaloneExecutorBackend( } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - master ! StatusUpdate(slaveId, taskId, state, data) + driver ! StatusUpdate(workerId, taskId, state, data) } } private[spark] object StandaloneExecutorBackend { - def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) { + def run(driverUrl: String, workerId: String, hostname: String, cores: Int) { // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) val actor = actorSystem.actorOf( - Props(new StandaloneExecutorBackend(new Executor, masterUrl, slaveId, hostname, cores)), + Props(new StandaloneExecutorBackend(new Executor, driverUrl, workerId, hostname, cores)), name = "Executor") actorSystem.awaitTermination() } def main(args: Array[String]) { if (args.length != 4) { - System.err.println("Usage: StandaloneExecutorBackend ") + System.err.println("Usage: StandaloneExecutorBackend ") System.exit(1) } run(args(0), args(1), args(2), args(3).toInt) diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 4f82cd96dd..866beb6d01 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,7 +19,7 @@ private[spark] class SparkDeploySchedulerBackend( var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt - val executorIdToSlaveId = new HashMap[String, String] + val executorIdToWorkerId = new HashMap[String, String] // Memory used by each executor (in megabytes) val executorMemory = { @@ -34,10 +34,11 @@ private[spark] class SparkDeploySchedulerBackend( override def start() { super.start() - val masterUrl = "akka://spark@%s:%s/user/%s".format( - System.getProperty("spark.master.host"), System.getProperty("spark.master.port"), + // The endpoint for executors to talk to us + val driverUrl = "akka://spark@%s:%s/user/%s".format( + System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), StandaloneSchedulerBackend.ACTOR_NAME) - val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") + val args = Seq(driverUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome) @@ -55,35 +56,35 @@ private[spark] class SparkDeploySchedulerBackend( } } - def connected(jobId: String) { + override def connected(jobId: String) { logInfo("Connected to Spark cluster with job ID " + jobId) } - def disconnected() { + override def disconnected() { if (!stopping) { logError("Disconnected from Spark cluster!") scheduler.error("Disconnected from Spark cluster") } } - def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) { - executorIdToSlaveId += id -> workerId + override def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int) { + executorIdToWorkerId += fullId -> workerId logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format( - id, host, cores, Utils.memoryMegabytesToString(memory))) + fullId, host, cores, Utils.memoryMegabytesToString(memory))) } - def executorRemoved(id: String, message: String, exitStatus: Option[Int]) { + override def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) { val reason: ExecutorLossReason = exitStatus match { case Some(code) => ExecutorExited(code) case None => SlaveLost(message) } - logInfo("Executor %s removed: %s".format(id, message)) - executorIdToSlaveId.get(id) match { - case Some(slaveId) => - executorIdToSlaveId.remove(id) - scheduler.slaveLost(slaveId, reason) + logInfo("Executor %s removed: %s".format(fullId, message)) + executorIdToWorkerId.get(fullId) match { + case Some(workerId) => + executorIdToWorkerId.remove(fullId) + scheduler.slaveLost(workerId, reason) case None => - logInfo("No slave ID known for executor %s".format(id)) + logInfo("No worker ID known for executor %s".format(fullId)) } } } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala index 1386cd9d44..bea9dc4f23 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -6,7 +6,7 @@ import spark.util.SerializableBuffer private[spark] sealed trait StandaloneClusterMessage extends Serializable -// Master to slaves +// Driver to executors private[spark] case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage @@ -16,7 +16,7 @@ case class RegisteredSlave(sparkProperties: Seq[(String, String)]) extends Stand private[spark] case class RegisterSlaveFailed(message: String) extends StandaloneClusterMessage -// Slaves to master +// Executors to driver private[spark] case class RegisterSlave(slaveId: String, host: String, cores: Int) extends StandaloneClusterMessage @@ -32,6 +32,6 @@ object StatusUpdate { } } -// Internal messages in master +// Internal messages in driver private[spark] case object ReviveOffers extends StandaloneClusterMessage -private[spark] case object StopMaster extends StandaloneClusterMessage +private[spark] case object StopDriver extends StandaloneClusterMessage diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index eeaae23dc8..d742a7b2bf 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -23,7 +23,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor // Use an atomic variable to track total number of cores in the cluster for simplicity and speed var totalCoreCount = new AtomicInteger(0) - class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor { + class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { val slaveActor = new HashMap[String, ActorRef] val slaveAddress = new HashMap[String, Address] val slaveHost = new HashMap[String, String] @@ -37,34 +37,34 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } def receive = { - case RegisterSlave(slaveId, host, cores) => - if (slaveActor.contains(slaveId)) { - sender ! RegisterSlaveFailed("Duplicate slave ID: " + slaveId) + case RegisterSlave(workerId, host, cores) => + if (slaveActor.contains(workerId)) { + sender ! RegisterSlaveFailed("Duplicate slave ID: " + workerId) } else { - logInfo("Registered slave: " + sender + " with ID " + slaveId) + logInfo("Registered slave: " + sender + " with ID " + workerId) sender ! RegisteredSlave(sparkProperties) context.watch(sender) - slaveActor(slaveId) = sender - slaveHost(slaveId) = host - freeCores(slaveId) = cores - slaveAddress(slaveId) = sender.path.address - actorToSlaveId(sender) = slaveId - addressToSlaveId(sender.path.address) = slaveId + slaveActor(workerId) = sender + slaveHost(workerId) = host + freeCores(workerId) = cores + slaveAddress(workerId) = sender.path.address + actorToSlaveId(sender) = workerId + addressToSlaveId(sender.path.address) = workerId totalCoreCount.addAndGet(cores) makeOffers() } - case StatusUpdate(slaveId, taskId, state, data) => + case StatusUpdate(workerId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { - freeCores(slaveId) += 1 - makeOffers(slaveId) + freeCores(workerId) += 1 + makeOffers(workerId) } case ReviveOffers => makeOffers() - case StopMaster => + case StopDriver => sender ! true context.stop(self) @@ -85,9 +85,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } // Make fake resource offers on just one slave - def makeOffers(slaveId: String) { + def makeOffers(workerId: String) { launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(slaveId, slaveHost(slaveId), freeCores(slaveId))))) + Seq(new WorkerOffer(workerId, slaveHost(workerId), freeCores(workerId))))) } // Launch tasks returned by a set of resource offers @@ -99,24 +99,24 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } // Remove a disconnected slave from the cluster - def removeSlave(slaveId: String, reason: String) { - logInfo("Slave " + slaveId + " disconnected, so removing it") - val numCores = freeCores(slaveId) - actorToSlaveId -= slaveActor(slaveId) - addressToSlaveId -= slaveAddress(slaveId) - slaveActor -= slaveId - slaveHost -= slaveId - freeCores -= slaveId - slaveHost -= slaveId + def removeSlave(workerId: String, reason: String) { + logInfo("Slave " + workerId + " disconnected, so removing it") + val numCores = freeCores(workerId) + actorToSlaveId -= slaveActor(workerId) + addressToSlaveId -= slaveAddress(workerId) + slaveActor -= workerId + slaveHost -= workerId + freeCores -= workerId + slaveHost -= workerId totalCoreCount.addAndGet(-numCores) - scheduler.slaveLost(slaveId, SlaveLost(reason)) + scheduler.slaveLost(workerId, SlaveLost(reason)) } } - var masterActor: ActorRef = null + var driverActor: ActorRef = null val taskIdsOnSlave = new HashMap[String, HashSet[String]] - def start() { + override def start() { val properties = new ArrayBuffer[(String, String)] val iterator = System.getProperties.entrySet.iterator while (iterator.hasNext) { @@ -126,15 +126,15 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor properties += ((key, value)) } } - masterActor = actorSystem.actorOf( - Props(new MasterActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) + driverActor = actorSystem.actorOf( + Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) } - def stop() { + override def stop() { try { - if (masterActor != null) { + if (driverActor != null) { val timeout = 5.seconds - val future = masterActor.ask(StopMaster)(timeout) + val future = driverActor.ask(StopDriver)(timeout) Await.result(future, timeout) } } catch { @@ -143,11 +143,11 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } } - def reviveOffers() { - masterActor ! ReviveOffers + override def reviveOffers() { + driverActor ! ReviveOffers } - def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2) + override def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2) } private[spark] object StandaloneSchedulerBackend { diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index 014906b028..7bf56a05d6 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -104,11 +104,11 @@ private[spark] class CoarseMesosSchedulerBackend( def createCommand(offer: Offer, numCores: Int): CommandInfo = { val runScript = new File(sparkHome, "run").getCanonicalPath - val masterUrl = "akka://spark@%s:%s/user/%s".format( - System.getProperty("spark.master.host"), System.getProperty("spark.master.port"), + val driverUrl = "akka://spark@%s:%s/user/%s".format( + System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), StandaloneSchedulerBackend.ACTOR_NAME) val command = "\"%s\" spark.executor.StandaloneExecutorBackend %s %s %s %d".format( - runScript, masterUrl, offer.getSlaveId.getValue, offer.getHostname, numCores) + runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores) val environment = Environment.newBuilder() sc.executorEnvs.foreach { case (key, value) => environment.addVariables(Environment.Variable.newBuilder() diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index a3d8671834..9fd2b454a4 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -11,52 +11,51 @@ import akka.util.duration._ import spark.{Logging, SparkException, Utils} - private[spark] class BlockManagerMaster( val actorSystem: ActorSystem, - isMaster: Boolean, + isDriver: Boolean, isLocal: Boolean, - masterIp: String, - masterPort: Int) + driverIp: String, + driverPort: Int) extends Logging { val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "3").toInt val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt - val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager" + val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager" val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager" val DEFAULT_MANAGER_IP: String = Utils.localHostName() val timeout = 10.seconds - var masterActor: ActorRef = { - if (isMaster) { - val masterActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)), - name = MASTER_AKKA_ACTOR_NAME) + var driverActor: ActorRef = { + if (isDriver) { + val driverActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)), + name = DRIVER_AKKA_ACTOR_NAME) logInfo("Registered BlockManagerMaster Actor") - masterActor + driverActor } else { - val url = "akka://spark@%s:%s/user/%s".format(masterIp, masterPort, MASTER_AKKA_ACTOR_NAME) + val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, DRIVER_AKKA_ACTOR_NAME) logInfo("Connecting to BlockManagerMaster: " + url) actorSystem.actorFor(url) } } - /** Remove a dead host from the master actor. This is only called on the master side. */ + /** Remove a dead host from the driver actor. This is only called on the driver side. */ def notifyADeadHost(host: String) { tell(RemoveHost(host)) logInfo("Removed " + host + " successfully in notifyADeadHost") } /** - * Send the master actor a heart beat from the slave. Returns true if everything works out, - * false if the master does not know about the given block manager, which means the block + * Send the driver actor a heart beat from the slave. Returns true if everything works out, + * false if the driver does not know about the given block manager, which means the block * manager should re-register. */ def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = { - askMasterWithRetry[Boolean](HeartBeat(blockManagerId)) + askDriverWithReply[Boolean](HeartBeat(blockManagerId)) } - /** Register the BlockManager's id with the master. */ + /** Register the BlockManager's id with the driver. */ def registerBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { logInfo("Trying to register BlockManager") @@ -70,25 +69,25 @@ private[spark] class BlockManagerMaster( storageLevel: StorageLevel, memSize: Long, diskSize: Long): Boolean = { - val res = askMasterWithRetry[Boolean]( + val res = askDriverWithReply[Boolean]( UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) logInfo("Updated info of block " + blockId) res } - /** Get locations of the blockId from the master */ + /** Get locations of the blockId from the driver */ def getLocations(blockId: String): Seq[BlockManagerId] = { - askMasterWithRetry[Seq[BlockManagerId]](GetLocations(blockId)) + askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) } - /** Get locations of multiple blockIds from the master */ + /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { - askMasterWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } - /** Get ids of other nodes in the cluster from the master */ + /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { - val result = askMasterWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) + val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) if (result.length != numPeers) { throw new SparkException( "Error getting peers, only got " + result.size + " instead of " + numPeers) @@ -98,10 +97,10 @@ private[spark] class BlockManagerMaster( /** * Remove a block from the slaves that have it. This can only be used to remove - * blocks that the master knows about. + * blocks that the driver knows about. */ def removeBlock(blockId: String) { - askMasterWithRetry(RemoveBlock(blockId)) + askDriverWithReply(RemoveBlock(blockId)) } /** @@ -111,33 +110,33 @@ private[spark] class BlockManagerMaster( * amount of remaining memory. */ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askMasterWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } - /** Stop the master actor, called only on the Spark master node */ + /** Stop the driver actor, called only on the Spark driver node */ def stop() { - if (masterActor != null) { + if (driverActor != null) { tell(StopBlockManagerMaster) - masterActor = null + driverActor = null logInfo("BlockManagerMaster stopped") } } /** Send a one-way message to the master actor, to which we expect it to reply with true. */ private def tell(message: Any) { - if (!askMasterWithRetry[Boolean](message)) { + if (!askDriverWithReply[Boolean](message)) { throw new SparkException("BlockManagerMasterActor returned false, expected true.") } } /** - * Send a message to the master actor and get its result within a default timeout, or + * Send a message to the driver actor and get its result within a default timeout, or * throw a SparkException if this fails. */ - private def askMasterWithRetry[T](message: Any): T = { + private def askDriverWithReply[T](message: Any): T = { // TODO: Consider removing multiple attempts - if (masterActor == null) { - throw new SparkException("Error sending message to BlockManager as masterActor is null " + + if (driverActor == null) { + throw new SparkException("Error sending message to BlockManager as driverActor is null " + "[message = " + message + "]") } var attempts = 0 @@ -145,7 +144,7 @@ private[spark] class BlockManagerMaster( while (attempts < AKKA_RETRY_ATTEMPS) { attempts += 1 try { - val future = masterActor.ask(message)(timeout) + val future = driverActor.ask(message)(timeout) val result = Await.result(future, timeout) if (result == null) { throw new Exception("BlockManagerMaster returned null") diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala index 689f07b969..0b8f6d4303 100644 --- a/core/src/main/scala/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/spark/storage/ThreadingTest.scala @@ -75,9 +75,9 @@ private[spark] object ThreadingTest { System.setProperty("spark.kryoserializer.buffer.mb", "1") val actorSystem = ActorSystem("test") val serializer = new KryoSerializer - val masterIp: String = System.getProperty("spark.master.host", "localhost") - val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt - val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, masterIp, masterPort) + val driverIp: String = System.getProperty("spark.driver.host", "localhost") + val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt + val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, driverIp, driverPort) val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer, 1024 * 1024) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 01351de4ae..42ce6f3c74 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -46,7 +46,7 @@ public class JavaAPISuite implements Serializable { sc.stop(); sc = null; // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port"); + System.clearProperty("spark.driver.port"); } static class ReverseIntComparator implements Comparator, Serializable { diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala index b5e31ddae3..ff00dd05dd 100644 --- a/core/src/test/scala/spark/LocalSparkContext.scala +++ b/core/src/test/scala/spark/LocalSparkContext.scala @@ -26,7 +26,7 @@ object LocalSparkContext { def stop(sc: SparkContext) { sc.stop() // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 7d5305f1e0..718107d2b5 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -79,7 +79,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("remote fetch") { val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) - System.setProperty("spark.master.port", boundPort.toString) + System.setProperty("spark.driver.port", boundPort.toString) val masterTracker = new MapOutputTracker(actorSystem, true) val slaveTracker = new MapOutputTracker(actorSystem, false) masterTracker.registerShuffle(10, 1) diff --git a/docs/configuration.md b/docs/configuration.md index 036a0df480..a7054b4321 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -202,7 +202,7 @@ Apart from these, the following properties are also available, and may be useful 10 Maximum message size to allow in "control plane" communication (for serialized tasks and task - results), in MB. Increase this if your tasks need to send back large results to the master + results), in MB. Increase this if your tasks need to send back large results to the driver (e.g. using collect() on a large dataset). @@ -211,7 +211,7 @@ Apart from these, the following properties are also available, and may be useful 4 Number of actor threads to use for communication. Can be useful to increase on large clusters - when the master has a lot of CPU cores. + when the driver has a lot of CPU cores. @@ -222,17 +222,17 @@ Apart from these, the following properties are also available, and may be useful - spark.master.host + spark.driver.host (local hostname) - Hostname or IP address for the master to listen on. + Hostname or IP address for the driver to listen on. - spark.master.port + spark.driver.port (random) - Port for the master to listen on. + Port for the driver to listen on. diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 46ab34f063..df7235756d 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -26,7 +26,7 @@ class PySparkTestCase(unittest.TestCase): sys.path = self._old_sys_path # To avoid Akka rebinding to the same port, since it doesn't unbind # immediately on shutdown - self.sc.jvm.System.clearProperty("spark.master.port") + self.sc.jvm.System.clearProperty("spark.driver.port") class TestCheckpoint(PySparkTestCase): diff --git a/repl/src/test/scala/spark/repl/ReplSuite.scala b/repl/src/test/scala/spark/repl/ReplSuite.scala index db78d06d4f..43559b96d3 100644 --- a/repl/src/test/scala/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/spark/repl/ReplSuite.scala @@ -31,7 +31,7 @@ class ReplSuite extends FunSuite { if (interp.sparkContext != null) interp.sparkContext.stop() // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") return out.toString } diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala index aa6be95f30..8c322dd698 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala @@ -153,8 +153,8 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log /** A helper actor that communicates with the NetworkInputTracker */ private class NetworkReceiverActor extends Actor { logInfo("Attempting to register with tracker") - val ip = System.getProperty("spark.master.host", "localhost") - val port = System.getProperty("spark.master.port", "7077").toInt + val ip = System.getProperty("spark.driver.host", "localhost") + val port = System.getProperty("spark.driver.port", "7077").toInt val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port) val tracker = env.actorSystem.actorFor(url) val timeout = 5.seconds diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index c84e7331c7..79d6093429 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -43,7 +43,7 @@ public class JavaAPISuite implements Serializable { ssc = null; // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port"); + System.clearProperty("spark.driver.port"); } @Test diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index bfdf32c73e..4a036f0710 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -10,7 +10,7 @@ class BasicOperationsSuite extends TestSuiteBase { after { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } test("map") { diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index d2f32c189b..563a7d1458 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -19,7 +19,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { FileUtils.deleteDirectory(new File(checkpointDir)) // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } var ssc: StreamingContext = null diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala index 7493ac1207..c4cfffbfc1 100644 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -24,7 +24,7 @@ class FailureSuite extends TestSuiteBase with BeforeAndAfter { FileUtils.deleteDirectory(new File(checkpointDir)) // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } override def framework = "CheckpointSuite" diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index d7ba7a5d17..70ae6e3934 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -42,7 +42,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } test("network input stream") { diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index 0c6e928835..cd9608df53 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -13,7 +13,7 @@ class WindowOperationsSuite extends TestSuiteBase { after { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } val largerSlideInput = Seq( -- cgit v1.2.3 From 539491bbc333834b9ae2721ae6cf3524cefb91ea Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 25 Jan 2013 09:29:59 -0800 Subject: code reformatting --- core/src/main/scala/spark/RDD.scala | 4 ++-- core/src/main/scala/spark/storage/BlockManagerUI.scala | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 870cc5ca78..4fcab9279a 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -94,7 +94,7 @@ abstract class RDD[T: ClassManifest]( /** How this RDD depends on any parent RDDs. */ protected def getDependencies(): List[Dependency[_]] = dependencies_ - // A friendly name for this RDD + /** A friendly name for this RDD */ var name: String = null /** Optionally overridden by subclasses to specify placement preferences. */ @@ -111,7 +111,7 @@ abstract class RDD[T: ClassManifest]( /** A unique ID for this RDD (within its SparkContext). */ val id = sc.newRddId() - /* Assign a name to this RDD */ + /** Assign a name to this RDD */ def setName(_name: String) = { name = _name this diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 35cbd59280..1003cc7a61 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -57,7 +57,8 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) - spark.storage.html.index.render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) + spark.storage.html.index. + render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) } }}} ~ get { path("rdd") { parameter("id") { id => { completeWith { @@ -67,9 +68,10 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray - val filteredStorageStatusList = StorageUtils.filterStorageStatusByPrefix(storageStatusList, prefix) + val filteredStorageStatusList = StorageUtils. + filterStorageStatusByPrefix(storageStatusList, prefix) - val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).first + val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) -- cgit v1.2.3 From 1cadaa164e9f078e4ca483edb9db7fd5507c9e64 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 25 Jan 2013 09:30:21 -0800 Subject: switch to TimeStampedHashMap for storing persistent Rdds --- core/src/main/scala/spark/SparkContext.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d994648899..10ceeb3028 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -44,6 +44,7 @@ import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} +import util.TimeStampedHashMap /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -110,7 +111,7 @@ class SparkContext( private[spark] val addedJars = HashMap[String, Long]() // Keeps track of all persisted RDDs - private[spark] val persistentRdds = new ConcurrentHashMap[Int, RDD[_]]() + private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]() // Add each JAR given through the constructor jars.foreach { addJar(_) } -- cgit v1.2.3 From a1d9d1767d821c1e25e485e32d9356b12aba6a01 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 25 Jan 2013 10:05:26 -0800 Subject: fixup 1cadaa1, changed api of map --- core/src/main/scala/spark/storage/StorageUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index 63ad5c125b..a10e3a95c6 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -56,8 +56,8 @@ object StorageUtils { // Find the id of the RDD, e.g. rdd_1 => 1 val rddId = rddKey.split("_").last.toInt // Get the friendly name for the rdd, if available. - val rddName = Option(sc.persistentRdds.get(rddId).name).getOrElse(rddKey) - val rddStorageLevel = sc.persistentRdds.get(rddId).getStorageLevel + val rddName = Option(sc.persistentRdds(rddId).name).getOrElse(rddKey) + val rddStorageLevel = sc.persistentRdds(rddId).getStorageLevel RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize) }.toArray -- cgit v1.2.3 From 8efbda0b179e3821a1221c6d78681fc74248cdac Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 25 Jan 2013 14:55:33 -0600 Subject: Call executeOnCompleteCallbacks in more finally blocks. --- .../main/scala/spark/scheduler/DAGScheduler.scala | 13 +++--- .../scala/spark/scheduler/ShuffleMapTask.scala | 46 +++++++++++----------- 2 files changed, 30 insertions(+), 29 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index b320be8863..f599eb00bd 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -40,7 +40,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with eventQueue.put(HostLost(host)) } - // Called by TaskScheduler to cancel an entier TaskSet due to repeated failures. + // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. override def taskSetFailed(taskSet: TaskSet, reason: String) { eventQueue.put(TaskSetFailed(taskSet, reason)) } @@ -54,8 +54,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // resubmit failed stages val POLL_TIMEOUT = 10L - private val lock = new Object // Used for access to the entire DAGScheduler - private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent] val nextRunId = new AtomicInteger(0) @@ -337,9 +335,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val rdd = job.finalStage.rdd val split = rdd.splits(job.partitions(0)) val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0) - val result = job.func(taskContext, rdd.iterator(split, taskContext)) - taskContext.executeOnCompleteCallbacks() - job.listener.taskSucceeded(0, result) + try { + val result = job.func(taskContext, rdd.iterator(split, taskContext)) + job.listener.taskSucceeded(0, result) + } finally { + taskContext.executeOnCompleteCallbacks() + } } catch { case e: Exception => job.listener.jobFailed(e) diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 19f5328eee..83641a2a84 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -81,7 +81,7 @@ private[spark] class ShuffleMapTask( with Externalizable with Logging { - def this() = this(0, null, null, 0, null) + protected def this() = this(0, null, null, 0, null) var split = if (rdd == null) { null @@ -117,34 +117,34 @@ private[spark] class ShuffleMapTask( override def run(attemptId: Long): MapStatus = { val numOutputSplits = dep.partitioner.numPartitions - val partitioner = dep.partitioner val taskContext = new TaskContext(stageId, partition, attemptId) + try { + // Partition the map output. + val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)]) + for (elem <- rdd.iterator(split, taskContext)) { + val pair = elem.asInstanceOf[(Any, Any)] + val bucketId = dep.partitioner.getPartition(pair._1) + buckets(bucketId) += pair + } + val bucketIterators = buckets.map(_.iterator) - // Partition the map output. - val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)]) - for (elem <- rdd.iterator(split, taskContext)) { - val pair = elem.asInstanceOf[(Any, Any)] - val bucketId = partitioner.getPartition(pair._1) - buckets(bucketId) += pair - } - val bucketIterators = buckets.map(_.iterator) + val compressedSizes = new Array[Byte](numOutputSplits) - val compressedSizes = new Array[Byte](numOutputSplits) + val blockManager = SparkEnv.get.blockManager + for (i <- 0 until numOutputSplits) { + val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i + // Get a Scala iterator from Java map + val iter: Iterator[(Any, Any)] = bucketIterators(i) + val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false) + compressedSizes(i) = MapOutputTracker.compressSize(size) + } - val blockManager = SparkEnv.get.blockManager - for (i <- 0 until numOutputSplits) { - val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i - // Get a Scala iterator from Java map - val iter: Iterator[(Any, Any)] = bucketIterators(i) - val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false) - compressedSizes(i) = MapOutputTracker.compressSize(size) + return new MapStatus(blockManager.blockManagerId, compressedSizes) + } finally { + // Execute the callbacks on task completion. + taskContext.executeOnCompleteCallbacks() } - - // Execute the callbacks on task completion. - taskContext.executeOnCompleteCallbacks() - - return new MapStatus(blockManager.blockManagerId, compressedSizes) } override def preferredLocations: Seq[String] = locs -- cgit v1.2.3 From 49c05608f5f27354da120e2367b6d4a63ec38948 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 25 Jan 2013 17:04:16 -0800 Subject: add metadatacleaner for persisentRdd map --- core/src/main/scala/spark/SparkContext.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 10ceeb3028..bff54dbdd1 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -44,7 +44,7 @@ import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import util.TimeStampedHashMap +import util.{MetadataCleaner, TimeStampedHashMap} /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -113,6 +113,9 @@ class SparkContext( // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]() + private[spark] val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) + + // Add each JAR given through the constructor jars.foreach { addJar(_) } @@ -512,6 +515,7 @@ class SparkContext( /** Shut down the SparkContext. */ def stop() { if (dagScheduler != null) { + metadataCleaner.cancel() dagScheduler.stop() dagScheduler = null taskScheduler = null @@ -654,6 +658,12 @@ class SparkContext( /** Register a new RDD, returning its RDD ID */ private[spark] def newRddId(): Int = nextRddId.getAndIncrement() + + private[spark] def cleanup(cleanupTime: Long) { + var sizeBefore = persistentRdds.size + persistentRdds.clearOldValues(cleanupTime) + logInfo("idToStage " + sizeBefore + " --> " + persistentRdds.size) + } } /** -- cgit v1.2.3 From d49cf0e587b7cbbd31917d9bb69f98466feb0f9f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 26 Jan 2013 15:57:01 -0800 Subject: Fix JavaRDDLike.flatMap(PairFlatMapFunction) (SPARK-668). This workaround is easier than rewriting JavaRDDLike in Java. --- .../main/scala/spark/api/java/JavaRDDLike.scala | 7 +++--- .../spark/api/java/PairFlatMapWorkaround.java | 20 ++++++++++++++++ core/src/test/scala/spark/JavaAPISuite.java | 28 ++++++++++++++++++++++ 3 files changed, 51 insertions(+), 4 deletions(-) create mode 100644 core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java (limited to 'core') diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index b3698ffa44..4c95c989b5 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -12,7 +12,7 @@ import spark.storage.StorageLevel import com.google.common.base.Optional -trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { +trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround[T] { def wrapRDD(rdd: RDD[T]): This implicit val classManifest: ClassManifest[T] @@ -82,10 +82,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Return a new RDD by first applying a function to all elements of this - * RDD, and then flattening the results. + * Part of the workaround for SPARK-668; called in PairFlatMapWorkaround.java. */ - def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = { + private[spark] def doFlatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = { import scala.collection.JavaConverters._ def fn = (x: T) => f.apply(x).asScala def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] diff --git a/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java new file mode 100644 index 0000000000..68b6fd6622 --- /dev/null +++ b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java @@ -0,0 +1,20 @@ +package spark.api.java; + +import spark.api.java.JavaPairRDD; +import spark.api.java.JavaRDDLike; +import spark.api.java.function.PairFlatMapFunction; + +import java.io.Serializable; + +/** + * Workaround for SPARK-668. + */ +class PairFlatMapWorkaround implements Serializable { + /** + * Return a new RDD by first applying a function to all elements of this + * RDD, and then flattening the results. + */ + public JavaPairRDD flatMap(PairFlatMapFunction f) { + return ((JavaRDDLike ) this).doFlatMap(f); + } +} diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 01351de4ae..f50ba093e9 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -355,6 +355,34 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(11, pairs.count()); } + @Test + public void mapsFromPairsToPairs() { + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD pairRDD = sc.parallelizePairs(pairs); + + // Regression test for SPARK-668: + JavaPairRDD swapped = pairRDD.flatMap( + new PairFlatMapFunction, String, Integer>() { + @Override + public Iterable> call(Tuple2 item) throws Exception { + return Collections.singletonList(item.swap()); + } + }); + swapped.collect(); + + // There was never a bug here, but it's worth testing: + pairRDD.map(new PairFunction, String, Integer>() { + @Override + public Tuple2 call(Tuple2 item) throws Exception { + return item.swap(); + } + }).collect(); + } + @Test public void mapPartitions() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); -- cgit v1.2.3 From ad4232b4dadc6290d3c4696d3cc007d3f01cb236 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Sat, 26 Jan 2013 18:07:14 -0800 Subject: Fix deadlock in BlockManager reregistration triggered by failed updates. --- .../main/scala/spark/storage/BlockManager.scala | 35 +++++++++++++++++-- .../scala/spark/storage/BlockManagerSuite.scala | 40 +++++++++++++++++++++- 2 files changed, 72 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 19cdaaa984..19d35b8667 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -90,7 +90,10 @@ class BlockManager( val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) - @volatile private var shuttingDown = false + // Pending reregistration action being executed asynchronously or null if none + // is pending. Accesses should synchronize on asyncReregisterLock. + var asyncReregisterTask: Future[Unit] = null + val asyncReregisterLock = new Object private def heartBeat() { if (!master.sendHeartBeat(blockManagerId)) { @@ -147,6 +150,8 @@ class BlockManager( /** * Reregister with the master and report all blocks to it. This will be called by the heart beat * thread if our heartbeat to the block amnager indicates that we were not registered. + * + * Note that this method must be called without any BlockInfo locks held. */ def reregister() { // TODO: We might need to rate limit reregistering. @@ -155,6 +160,32 @@ class BlockManager( reportAllBlocks() } + /** + * Reregister with the master sometime soon. + */ + def asyncReregister() { + asyncReregisterLock.synchronized { + if (asyncReregisterTask == null) { + asyncReregisterTask = Future[Unit] { + reregister() + asyncReregisterLock.synchronized { + asyncReregisterTask = null + } + } + } + } + } + + /** + * For testing. Wait for any pending asynchronous reregistration; otherwise, do nothing. + */ + def waitForAsyncReregister() { + val task = asyncReregisterTask + if (task != null) { + Await.ready(task, Duration.Inf) + } + } + /** * Get storage level of local block. If no info exists for the block, then returns null. */ @@ -170,7 +201,7 @@ class BlockManager( if (needReregister) { logInfo("Got told to reregister updating block " + blockId) // Reregistering will report our new block for free. - reregister() + asyncReregister() } logDebug("Told master about block " + blockId) } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index a1aeb12f25..2165744689 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -219,18 +219,56 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val a2 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - assert(master.getLocations("a1").size > 0, "master was not told about a1") master.notifyADeadHost(store.blockManagerId.ip) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY) + store.waitForAsyncReregister() assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") assert(master.getLocations("a2").size > 0, "master was not told about a2") } + test("reregistration doesn't dead lock") { + val heartBeat = PrivateMethod[Unit]('heartBeat) + store = new BlockManager(actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = List(new Array[Byte](400)) + + // try many times to trigger any deadlocks + for (i <- 1 to 100) { + master.notifyADeadHost(store.blockManagerId.ip) + val t1 = new Thread { + override def run = { + store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, true) + } + } + val t2 = new Thread { + override def run = { + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + } + } + val t3 = new Thread { + override def run = { + store invokePrivate heartBeat() + } + } + + t1.start + t2.start + t3.start + t1.join + t2.join + t3.join + + store.dropFromMemory("a1", null) + store.dropFromMemory("a2", null) + store.waitForAsyncReregister() + } + } + test("in-memory LRU storage") { store = new BlockManager(actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) -- cgit v1.2.3 From 58fc6b2bed9f660fbf134aab188827b7d8975a62 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Sat, 26 Jan 2013 18:07:53 -0800 Subject: Handle duplicate registrations better. --- core/src/main/scala/spark/storage/BlockManagerMasterActor.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index f4d026da33..2216c33b76 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -183,7 +183,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { if (blockManagerId.ip == Utils.localHostName() && !isLocal) { logInfo("Got Register Msg from master node, don't register it") - } else { + } else if (!blockManagerInfo.contains(blockManagerId)) { blockManagerIdByHost.get(blockManagerId.ip) match { case Some(managers) => // A block manager of the same host name already exists. -- cgit v1.2.3 From 44b4a0f88fcb31727347b755ae8ec14d69571b52 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 27 Jan 2013 19:23:49 -0800 Subject: Track workers by executor ID instead of hostname to allow multiple executors per machine and remove the need for multiple IP addresses in unit tests. --- core/src/main/scala/spark/MapOutputTracker.scala | 4 +- core/src/main/scala/spark/SparkContext.scala | 6 +- core/src/main/scala/spark/SparkEnv.scala | 9 +- .../scala/spark/deploy/LocalSparkCluster.scala | 16 +-- .../main/scala/spark/deploy/master/Master.scala | 4 +- .../scala/spark/deploy/worker/ExecutorRunner.scala | 2 +- core/src/main/scala/spark/executor/Executor.scala | 4 +- .../spark/executor/MesosExecutorBackend.scala | 3 +- .../spark/executor/StandaloneExecutorBackend.scala | 14 +-- .../main/scala/spark/scheduler/DAGScheduler.scala | 44 ++++----- .../scala/spark/scheduler/DAGSchedulerEvent.scala | 2 +- .../src/main/scala/spark/scheduler/MapStatus.scala | 6 +- core/src/main/scala/spark/scheduler/Stage.scala | 11 ++- .../spark/scheduler/TaskSchedulerListener.scala | 2 +- .../spark/scheduler/cluster/ClusterScheduler.scala | 110 +++++++++++---------- .../cluster/SparkDeploySchedulerBackend.scala | 4 +- .../cluster/StandaloneSchedulerBackend.scala | 64 ++++++------ .../spark/scheduler/cluster/TaskDescription.scala | 2 +- .../scala/spark/scheduler/cluster/TaskInfo.scala | 7 +- .../spark/scheduler/cluster/TaskSetManager.scala | 38 +++---- .../spark/scheduler/cluster/WorkerOffer.scala | 4 +- .../scheduler/mesos/MesosSchedulerBackend.scala | 2 +- .../main/scala/spark/storage/BlockManager.scala | 10 +- .../main/scala/spark/storage/BlockManagerId.scala | 27 +++-- .../scala/spark/storage/BlockManagerMaster.scala | 12 +-- .../spark/storage/BlockManagerMasterActor.scala | 66 +++++-------- .../scala/spark/storage/BlockManagerMessages.scala | 2 +- .../main/scala/spark/storage/BlockManagerUI.scala | 7 +- .../main/scala/spark/storage/ThreadingTest.scala | 3 +- core/src/main/scala/spark/util/AkkaUtils.scala | 6 +- .../main/scala/spark/util/TimeStampedHashMap.scala | 4 +- core/src/test/scala/spark/DriverSuite.scala | 5 +- .../test/scala/spark/MapOutputTrackerSuite.scala | 69 +++++++------ .../scala/spark/storage/BlockManagerSuite.scala | 86 ++++++++-------- sbt/sbt | 2 +- 35 files changed, 343 insertions(+), 314 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index ac02f3363a..c1f012b419 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -114,7 +114,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea var array = mapStatuses(shuffleId) if (array != null) { array.synchronized { - if (array(mapId) != null && array(mapId).address == bmAddress) { + if (array(mapId) != null && array(mapId).location == bmAddress) { array(mapId) = null } } @@ -277,7 +277,7 @@ private[spark] object MapOutputTracker { throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing an output location for shuffle " + shuffleId)) } else { - (status.address, decompressSize(status.compressedSizes(reduceId))) + (status.location, decompressSize(status.compressedSizes(reduceId))) } } } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 4581c0adcf..39721b47ae 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -80,6 +80,7 @@ class SparkContext( // Create the Spark execution environment (cache, map output tracker, etc) private[spark] val env = SparkEnv.createFromSystemProperties( + "", System.getProperty("spark.master.host"), System.getProperty("spark.master.port").toInt, true, @@ -97,7 +98,7 @@ class SparkContext( // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]() - private[spark] val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) + private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup) // Add each JAR given through the constructor @@ -649,10 +650,9 @@ class SparkContext( /** Register a new RDD, returning its RDD ID */ private[spark] def newRddId(): Int = nextRddId.getAndIncrement() + /** Called by MetadataCleaner to clean up the persistentRdds map periodically */ private[spark] def cleanup(cleanupTime: Long) { - var sizeBefore = persistentRdds.size persistentRdds.clearOldValues(cleanupTime) - logInfo("idToStage " + sizeBefore + " --> " + persistentRdds.size) } } diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 2a7a8af83d..0c094edcf3 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -19,6 +19,7 @@ import spark.util.AkkaUtils * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set. */ class SparkEnv ( + val executorId: String, val actorSystem: ActorSystem, val serializer: Serializer, val closureSerializer: Serializer, @@ -58,11 +59,12 @@ object SparkEnv extends Logging { } def createFromSystemProperties( + executorId: String, hostname: String, port: Int, isMaster: Boolean, - isLocal: Boolean - ) : SparkEnv = { + isLocal: Boolean): SparkEnv = { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port) // Bit of a hack: If this is the master and our port was 0 (meaning bind to any free port), @@ -86,7 +88,7 @@ object SparkEnv extends Logging { val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt val blockManagerMaster = new BlockManagerMaster( actorSystem, isMaster, isLocal, masterIp, masterPort) - val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer) + val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer) val connectionManager = blockManager.connectionManager @@ -122,6 +124,7 @@ object SparkEnv extends Logging { } new SparkEnv( + executorId, actorSystem, serializer, closureSerializer, diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala index 4211d80596..8f51051e39 100644 --- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala @@ -9,6 +9,12 @@ import spark.{Logging, Utils} import scala.collection.mutable.ArrayBuffer +/** + * Testing class that creates a Spark standalone process in-cluster (that is, running the + * spark.deploy.master.Master and spark.deploy.worker.Workers in the same JVMs). Executors launched + * by the Workers still run in separate JVMs. This can be used to test distributed operation and + * fault recovery without spinning up a lot of processes. + */ private[spark] class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging { @@ -35,16 +41,12 @@ class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) /* Start the Slaves */ for (slaveNum <- 1 to numSlaves) { - /* We can pretend to test distributed stuff by giving the slaves distinct hostnames. - All of 127/8 should be a loopback, we use 127.100.*.* in hopes that it is - sufficiently distinctive. */ - val slaveIpAddress = "127.100.0." + (slaveNum % 256) val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("sparkWorker" + slaveNum, slaveIpAddress, 0) + AkkaUtils.createActorSystem("sparkWorker" + slaveNum, localIpAddress, 0) slaveActorSystems += actorSystem val actor = actorSystem.actorOf( - Props(new Worker(slaveIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)), - name = "Worker") + Props(new Worker(localIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)), + name = "Worker") slaveActors += actor } diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 2c2cd0231b..2e7e868579 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -97,10 +97,10 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor exec.worker.removeExecutor(exec) // Only retry certain number of times so we don't go into an infinite loop. - if (jobInfo.incrementRetryCount <= JobState.MAX_NUM_RETRY) { + if (jobInfo.incrementRetryCount < JobState.MAX_NUM_RETRY) { schedule() } else { - val e = new SparkException("Job %s wth ID %s failed %d times.".format( + val e = new SparkException("Job %s with ID %s failed %d times.".format( jobInfo.desc.name, jobInfo.id, jobInfo.retryCount)) logError(e.getMessage, e) throw e diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index 0d1fe2a6b4..af3acfecb6 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -67,7 +67,7 @@ private[spark] class ExecutorRunner( /** Replace variables such as {{SLAVEID}} and {{CORES}} in a command argument passed to us */ def substituteVariables(argument: String): String = argument match { - case "{{SLAVEID}}" => workerId + case "{{EXECUTOR_ID}}" => execId.toString case "{{HOSTNAME}}" => hostname case "{{CORES}}" => cores.toString case other => other diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 28d9d40d43..bd21ba719a 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -30,7 +30,7 @@ private[spark] class Executor extends Logging { initLogging() - def initialize(slaveHostname: String, properties: Seq[(String, String)]) { + def initialize(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) { // Make sure the local hostname we report matches the cluster scheduler's name for this host Utils.setCustomHostname(slaveHostname) @@ -64,7 +64,7 @@ private[spark] class Executor extends Logging { ) // Initialize Spark environment (using system properties read above) - env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false) + env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false) SparkEnv.set(env) // Start worker thread pool diff --git a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala index eeab3959c6..1ef88075ad 100644 --- a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala @@ -29,9 +29,10 @@ private[spark] class MesosExecutorBackend(executor: Executor) executorInfo: ExecutorInfo, frameworkInfo: FrameworkInfo, slaveInfo: SlaveInfo) { + logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue) this.driver = driver val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) - executor.initialize(slaveInfo.getHostname, properties) + executor.initialize(executorInfo.getExecutorId.getValue, slaveInfo.getHostname, properties) } override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index a29bf974d2..435ee5743e 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -17,7 +17,7 @@ import spark.scheduler.cluster.RegisterSlave private[spark] class StandaloneExecutorBackend( executor: Executor, masterUrl: String, - slaveId: String, + executorId: String, hostname: String, cores: Int) extends Actor @@ -30,7 +30,7 @@ private[spark] class StandaloneExecutorBackend( try { logInfo("Connecting to master: " + masterUrl) master = context.actorFor(masterUrl) - master ! RegisterSlave(slaveId, hostname, cores) + master ! RegisterSlave(executorId, hostname, cores) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(master) // Doesn't work with remote actors, but useful for testing } catch { @@ -43,7 +43,7 @@ private[spark] class StandaloneExecutorBackend( override def receive = { case RegisteredSlave(sparkProperties) => logInfo("Successfully registered with master") - executor.initialize(hostname, sparkProperties) + executor.initialize(executorId, hostname, sparkProperties) case RegisterSlaveFailed(message) => logError("Slave registration failed: " + message) @@ -55,24 +55,24 @@ private[spark] class StandaloneExecutorBackend( } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - master ! StatusUpdate(slaveId, taskId, state, data) + master ! StatusUpdate(executorId, taskId, state, data) } } private[spark] object StandaloneExecutorBackend { - def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) { + def run(masterUrl: String, executorId: String, hostname: String, cores: Int) { // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) val actor = actorSystem.actorOf( - Props(new StandaloneExecutorBackend(new Executor, masterUrl, slaveId, hostname, cores)), + Props(new StandaloneExecutorBackend(new Executor, masterUrl, executorId, hostname, cores)), name = "Executor") actorSystem.awaitTermination() } def main(args: Array[String]) { if (args.length != 4) { - System.err.println("Usage: StandaloneExecutorBackend ") + System.err.println("Usage: StandaloneExecutorBackend ") System.exit(1) } run(args(0), args(1), args(2), args(3).toInt) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index f599eb00bd..bd541d4207 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -35,9 +35,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with eventQueue.put(CompletionEvent(task, reason, result, accumUpdates)) } - // Called by TaskScheduler when a host fails. - override def hostLost(host: String) { - eventQueue.put(HostLost(host)) + // Called by TaskScheduler when an executor fails. + override def executorLost(execId: String) { + eventQueue.put(ExecutorLost(execId)) } // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. @@ -72,7 +72,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // For tracking failed nodes, we use the MapOutputTracker's generation number, which is // sent with every task. When we detect a node failing, we note the current generation number - // and failed host, increment it for new tasks, and use this to ignore stray ShuffleMapTask + // and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask // results. // TODO: Garbage collect information about failure generations when we know there are no more // stray messages to detect. @@ -108,7 +108,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } def clearCacheLocs() { - cacheLocs.clear + cacheLocs.clear() } /** @@ -271,8 +271,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with submitStage(finalStage) } - case HostLost(host) => - handleHostLost(host) + case ExecutorLost(execId) => + handleExecutorLost(execId) case completion: CompletionEvent => handleTaskCompletion(completion) @@ -436,10 +436,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with case smt: ShuffleMapTask => val stage = idToStage(smt.stageId) val status = event.result.asInstanceOf[MapStatus] - val host = status.address.ip - logInfo("ShuffleMapTask finished with host " + host) - if (failedGeneration.contains(host) && smt.generation <= failedGeneration(host)) { - logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + host) + val execId = status.location.executorId + logDebug("ShuffleMapTask finished on " + execId) + if (failedGeneration.contains(execId) && smt.generation <= failedGeneration(execId)) { + logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) } else { stage.addOutputLoc(smt.partition, status) } @@ -511,9 +511,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // Remember that a fetch failed now; this is used to resubmit the broken // stages later, after a small wait (to give other tasks the chance to fail) lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock - // TODO: mark the host as failed only if there were lots of fetch failures on it + // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleHostLost(bmAddress.ip, Some(task.generation)) + handleExecutorLost(bmAddress.executorId, Some(task.generation)) } case other => @@ -523,21 +523,21 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } /** - * Responds to a host being lost. This is called inside the event loop so it assumes that it can - * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside. + * Responds to an executor being lost. This is called inside the event loop, so it assumes it can + * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. * * Optionally the generation during which the failure was caught can be passed to avoid allowing * stray fetch failures from possibly retriggering the detection of a node as lost. */ - def handleHostLost(host: String, maybeGeneration: Option[Long] = None) { + def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) { val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration) - if (!failedGeneration.contains(host) || failedGeneration(host) < currentGeneration) { - failedGeneration(host) = currentGeneration - logInfo("Host lost: " + host + " (generation " + currentGeneration + ")") - env.blockManager.master.notifyADeadHost(host) + if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) { + failedGeneration(execId) = currentGeneration + logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration)) + env.blockManager.master.removeExecutor(execId) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { - stage.removeOutputsOnHost(host) + stage.removeOutputsOnExecutor(execId) val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray mapOutputTracker.registerMapOutputs(shuffleId, locs, true) } @@ -546,7 +546,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } clearCacheLocs() } else { - logDebug("Additional host lost message for " + host + + logDebug("Additional executor lost message for " + execId + "(generation " + currentGeneration + ")") } } diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala index 3422a21d9d..b34fa78c07 100644 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala @@ -28,7 +28,7 @@ private[spark] case class CompletionEvent( accumUpdates: Map[Long, Any]) extends DAGSchedulerEvent -private[spark] case class HostLost(host: String) extends DAGSchedulerEvent +private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala index fae643f3a8..203abb917b 100644 --- a/core/src/main/scala/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/spark/scheduler/MapStatus.scala @@ -8,19 +8,19 @@ import java.io.{ObjectOutput, ObjectInput, Externalizable} * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. * The map output sizes are compressed using MapOutputTracker.compressSize. */ -private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes: Array[Byte]) +private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte]) extends Externalizable { def this() = this(null, null) // For deserialization only def writeExternal(out: ObjectOutput) { - address.writeExternal(out) + location.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) } def readExternal(in: ObjectInput) { - address = BlockManagerId(in) + location = BlockManagerId(in) compressedSizes = new Array[Byte](in.readInt()) in.readFully(compressedSizes) } diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala index 4846b66729..e9419728e3 100644 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -51,18 +51,18 @@ private[spark] class Stage( def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.address == bmAddress) + val newList = prevList.filterNot(_.location == bmAddress) outputLocs(partition) = newList if (prevList != Nil && newList == Nil) { numAvailableOutputs -= 1 } } - def removeOutputsOnHost(host: String) { + def removeOutputsOnExecutor(execId: String) { var becameUnavailable = false for (partition <- 0 until numPartitions) { val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.address.ip == host) + val newList = prevList.filterNot(_.location.executorId == execId) outputLocs(partition) = newList if (prevList != Nil && newList == Nil) { becameUnavailable = true @@ -70,7 +70,8 @@ private[spark] class Stage( } } if (becameUnavailable) { - logInfo("%s is now unavailable on %s (%d/%d, %s)".format(this, host, numAvailableOutputs, numPartitions, isAvailable)) + logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( + this, execId, numAvailableOutputs, numPartitions, isAvailable)) } } @@ -82,7 +83,7 @@ private[spark] class Stage( def origin: String = rdd.origin - override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]" + override def toString = "Stage " + id override def hashCode(): Int = id } diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala index fa4de15d0d..9fcef86e46 100644 --- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala +++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala @@ -12,7 +12,7 @@ private[spark] trait TaskSchedulerListener { def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit // A node was lost from the cluster. - def hostLost(host: String): Unit + def executorLost(execId: String): Unit // The TaskScheduler wants to abort an entire task set. def taskSetFailed(taskSet: TaskSet, reason: String): Unit diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index a639b72795..0b4177805b 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -27,19 +27,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext) var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] val taskIdToTaskSetId = new HashMap[Long, String] - val taskIdToSlaveId = new HashMap[Long, String] + val taskIdToExecutorId = new HashMap[Long, String] val taskSetTaskIds = new HashMap[String, HashSet[Long]] // Incrementing Mesos task IDs val nextTaskId = new AtomicLong(0) - // Which hosts in the cluster are alive (contains hostnames) - val hostsAlive = new HashSet[String] + // Which executor IDs we have executors on + val activeExecutorIds = new HashSet[String] - // Which slave IDs we have executors on - val slaveIdsWithExecutors = new HashSet[String] + // The set of executors we have on each host; this is used to compute hostsAlive, which + // in turn is used to decide when we can attain data locality on a given host + val executorsByHost = new HashMap[String, HashSet[String]] - val slaveIdToHost = new HashMap[String, String] + val executorIdToHost = new HashMap[String, String] // JAR server, if any JARs were added by the user to the SparkContext var jarServer: HttpServer = null @@ -102,7 +103,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) activeTaskSets -= manager.taskSet.id activeTaskSetsQueue -= manager taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id) + taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) taskSetTaskIds.remove(manager.taskSet.id) } } @@ -117,8 +118,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname for (o <- offers) { - slaveIdToHost(o.slaveId) = o.hostname - hostsAlive += o.hostname + executorIdToHost(o.executorId) = o.hostname } // Build a list of tasks to assign to each slave val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) @@ -128,16 +128,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext) do { launchedTask = false for (i <- 0 until offers.size) { - val sid = offers(i).slaveId + val execId = offers(i).executorId val host = offers(i).hostname - manager.slaveOffer(sid, host, availableCpus(i)) match { + manager.slaveOffer(execId, host, availableCpus(i)) match { case Some(task) => tasks(i) += task val tid = task.taskId taskIdToTaskSetId(tid) = manager.taskSet.id taskSetTaskIds(manager.taskSet.id) += tid - taskIdToSlaveId(tid) = sid - slaveIdsWithExecutors += sid + taskIdToExecutorId(tid) = execId + activeExecutorIds += execId + if (!executorsByHost.contains(host)) { + executorsByHost(host) = new HashSet() + } + executorsByHost(host) += execId availableCpus(i) -= 1 launchedTask = true @@ -152,25 +156,21 @@ private[spark] class ClusterScheduler(val sc: SparkContext) def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { var taskSetToUpdate: Option[TaskSetManager] = None - var failedHost: Option[String] = None + var failedExecutor: Option[String] = None var taskFailed = false synchronized { try { - if (state == TaskState.LOST && taskIdToSlaveId.contains(tid)) { - // We lost the executor on this slave, so remember that it's gone - val slaveId = taskIdToSlaveId(tid) - val host = slaveIdToHost(slaveId) - if (hostsAlive.contains(host)) { - slaveIdsWithExecutors -= slaveId - hostsAlive -= host - activeTaskSetsQueue.foreach(_.hostLost(host)) - failedHost = Some(host) + if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { + // We lost this entire executor, so remember that it's gone + val execId = taskIdToExecutorId(tid) + if (activeExecutorIds.contains(execId)) { + removeExecutor(execId) + failedExecutor = Some(execId) } } taskIdToTaskSetId.get(tid) match { case Some(taskSetId) => if (activeTaskSets.contains(taskSetId)) { - //activeTaskSets(taskSetId).statusUpdate(status) taskSetToUpdate = Some(activeTaskSets(taskSetId)) } if (TaskState.isFinished(state)) { @@ -178,7 +178,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) if (taskSetTaskIds.contains(taskSetId)) { taskSetTaskIds(taskSetId) -= tid } - taskIdToSlaveId.remove(tid) + taskIdToExecutorId.remove(tid) } if (state == TaskState.FAILED) { taskFailed = true @@ -190,12 +190,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext) case e: Exception => logError("Exception in statusUpdate", e) } } - // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock + // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock if (taskSetToUpdate != None) { taskSetToUpdate.get.statusUpdate(tid, state, serializedData) } - if (failedHost != None) { - listener.hostLost(failedHost.get) + if (failedExecutor != None) { + listener.executorLost(failedExecutor.get) backend.reviveOffers() } if (taskFailed) { @@ -249,32 +249,42 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - def slaveLost(slaveId: String, reason: ExecutorLossReason) { - var failedHost: Option[String] = None + def executorLost(executorId: String, reason: ExecutorLossReason) { + var failedExecutor: Option[String] = None synchronized { - slaveIdToHost.get(slaveId) match { - case Some(host) => - if (hostsAlive.contains(host)) { - logError("Lost an executor on " + host + ": " + reason) - slaveIdsWithExecutors -= slaveId - hostsAlive -= host - activeTaskSetsQueue.foreach(_.hostLost(host)) - failedHost = Some(host) - } else { - // We may get multiple slaveLost() calls with different loss reasons. For example, one - // may be triggered by a dropped connection from the slave while another may be a report - // of executor termination from Mesos. We produce log messages for both so we eventually - // report the termination reason. - logError("Lost an executor on " + host + " (already removed): " + reason) - } - case None => - // We were told about a slave being lost before we could even allocate work to it - logError("Lost slave " + slaveId + " (no work assigned yet)") + if (activeExecutorIds.contains(executorId)) { + val host = executorIdToHost(executorId) + logError("Lost executor %s on %s: %s".format(executorId, host, reason)) + removeExecutor(executorId) + failedExecutor = Some(executorId) + } else { + // We may get multiple executorLost() calls with different loss reasons. For example, one + // may be triggered by a dropped connection from the slave while another may be a report + // of executor termination from Mesos. We produce log messages for both so we eventually + // report the termination reason. + logError("Lost an executor " + executorId + " (already removed): " + reason) } } - if (failedHost != None) { - listener.hostLost(failedHost.get) + // Call listener.executorLost without holding the lock on this to prevent deadlock + if (failedExecutor != None) { + listener.executorLost(failedExecutor.get) backend.reviveOffers() } } + + /** Get a list of hosts that currently have executors */ + def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet + + /** Remove an executor from all our data structures and mark it as lost */ + private def removeExecutor(executorId: String) { + activeExecutorIds -= executorId + val host = executorIdToHost(executorId) + val execs = executorsByHost.getOrElse(host, new HashSet) + execs -= executorId + if (execs.isEmpty) { + executorsByHost -= host + } + executorIdToHost -= executorId + activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) + } } diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 4f82cd96dd..f0792c1b76 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -37,7 +37,7 @@ private[spark] class SparkDeploySchedulerBackend( val masterUrl = "akka://spark@%s:%s/user/%s".format( System.getProperty("spark.master.host"), System.getProperty("spark.master.port"), StandaloneSchedulerBackend.ACTOR_NAME) - val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") + val args = Seq(masterUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome) @@ -81,7 +81,7 @@ private[spark] class SparkDeploySchedulerBackend( executorIdToSlaveId.get(id) match { case Some(slaveId) => executorIdToSlaveId.remove(id) - scheduler.slaveLost(slaveId, reason) + scheduler.executorLost(slaveId, reason) case None => logInfo("No slave ID known for executor %s".format(id)) } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index eeaae23dc8..32be1e7a26 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -28,8 +28,8 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor val slaveAddress = new HashMap[String, Address] val slaveHost = new HashMap[String, String] val freeCores = new HashMap[String, Int] - val actorToSlaveId = new HashMap[ActorRef, String] - val addressToSlaveId = new HashMap[Address, String] + val actorToExecutorId = new HashMap[ActorRef, String] + val addressToExecutorId = new HashMap[Address, String] override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -37,28 +37,28 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } def receive = { - case RegisterSlave(slaveId, host, cores) => - if (slaveActor.contains(slaveId)) { - sender ! RegisterSlaveFailed("Duplicate slave ID: " + slaveId) + case RegisterSlave(executorId, host, cores) => + if (slaveActor.contains(executorId)) { + sender ! RegisterSlaveFailed("Duplicate executor ID: " + executorId) } else { - logInfo("Registered slave: " + sender + " with ID " + slaveId) + logInfo("Registered executor: " + sender + " with ID " + executorId) sender ! RegisteredSlave(sparkProperties) context.watch(sender) - slaveActor(slaveId) = sender - slaveHost(slaveId) = host - freeCores(slaveId) = cores - slaveAddress(slaveId) = sender.path.address - actorToSlaveId(sender) = slaveId - addressToSlaveId(sender.path.address) = slaveId + slaveActor(executorId) = sender + slaveHost(executorId) = host + freeCores(executorId) = cores + slaveAddress(executorId) = sender.path.address + actorToExecutorId(sender) = executorId + addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) makeOffers() } - case StatusUpdate(slaveId, taskId, state, data) => + case StatusUpdate(executorId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { - freeCores(slaveId) += 1 - makeOffers(slaveId) + freeCores(executorId) += 1 + makeOffers(executorId) } case ReviveOffers => @@ -69,13 +69,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor context.stop(self) case Terminated(actor) => - actorToSlaveId.get(actor).foreach(removeSlave(_, "Akka actor terminated")) + actorToExecutorId.get(actor).foreach(removeSlave(_, "Akka actor terminated")) case RemoteClientDisconnected(transport, address) => - addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client disconnected")) + addressToExecutorId.get(address).foreach(removeSlave(_, "remote Akka client disconnected")) case RemoteClientShutdown(transport, address) => - addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client shutdown")) + addressToExecutorId.get(address).foreach(removeSlave(_, "remote Akka client shutdown")) } // Make fake resource offers on all slaves @@ -85,31 +85,31 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } // Make fake resource offers on just one slave - def makeOffers(slaveId: String) { + def makeOffers(executorId: String) { launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(slaveId, slaveHost(slaveId), freeCores(slaveId))))) + Seq(new WorkerOffer(executorId, slaveHost(executorId), freeCores(executorId))))) } // Launch tasks returned by a set of resource offers def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { - freeCores(task.slaveId) -= 1 - slaveActor(task.slaveId) ! LaunchTask(task) + freeCores(task.executorId) -= 1 + slaveActor(task.executorId) ! LaunchTask(task) } } // Remove a disconnected slave from the cluster - def removeSlave(slaveId: String, reason: String) { - logInfo("Slave " + slaveId + " disconnected, so removing it") - val numCores = freeCores(slaveId) - actorToSlaveId -= slaveActor(slaveId) - addressToSlaveId -= slaveAddress(slaveId) - slaveActor -= slaveId - slaveHost -= slaveId - freeCores -= slaveId - slaveHost -= slaveId + def removeSlave(executorId: String, reason: String) { + logInfo("Slave " + executorId + " disconnected, so removing it") + val numCores = freeCores(executorId) + actorToExecutorId -= slaveActor(executorId) + addressToExecutorId -= slaveAddress(executorId) + slaveActor -= executorId + slaveHost -= executorId + freeCores -= executorId + slaveHost -= executorId totalCoreCount.addAndGet(-numCores) - scheduler.slaveLost(slaveId, SlaveLost(reason)) + scheduler.executorLost(executorId, SlaveLost(reason)) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala index aa097fd3a2..b41e951be9 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala @@ -5,7 +5,7 @@ import spark.util.SerializableBuffer private[spark] class TaskDescription( val taskId: Long, - val slaveId: String, + val executorId: String, val name: String, _serializedTask: ByteBuffer) extends Serializable { diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala index ca84503780..0f975ce1eb 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala @@ -4,7 +4,12 @@ package spark.scheduler.cluster * Information about a running task attempt inside a TaskSet. */ private[spark] -class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host: String) { +class TaskInfo( + val taskId: Long, + val index: Int, + val launchTime: Long, + val executorId: String, + val host: String) { var finishTime: Long = 0 var failed = false diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index a089b71644..26201ad0dd 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -138,10 +138,11 @@ private[spark] class TaskSetManager( // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the // task must have a preference for this host (or no preferred locations at all). def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { + val hostsAlive = sched.hostsAlive speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set val localTask = speculatableTasks.find { index => - val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive + val locations = tasks(index).preferredLocations.toSet & hostsAlive val attemptLocs = taskAttempts(index).map(_.host) (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host) } @@ -189,7 +190,7 @@ private[spark] class TaskSetManager( } // Respond to an offer of a single slave from the scheduler by finding a task - def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[TaskDescription] = { + def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { val time = System.currentTimeMillis val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) @@ -206,11 +207,11 @@ private[spark] class TaskSetManager( } else { "non-preferred, not one of " + task.preferredLocations.mkString(", ") } - logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( - taskSet.id, index, taskId, slaveId, host, prefStr)) + logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format( + taskSet.id, index, taskId, execId, host, prefStr)) // Do various bookkeeping copiesRunning(index) += 1 - val info = new TaskInfo(taskId, index, time, host) + val info = new TaskInfo(taskId, index, time, execId, host) taskInfos(taskId) = info taskAttempts(index) = info :: taskAttempts(index) if (preferred) { @@ -224,7 +225,7 @@ private[spark] class TaskSetManager( logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) val taskName = "task %s:%d".format(taskSet.id, index) - return Some(new TaskDescription(taskId, slaveId, taskName, serializedTask)) + return Some(new TaskDescription(taskId, execId, taskName, serializedTask)) } case _ => } @@ -356,19 +357,22 @@ private[spark] class TaskSetManager( sched.taskSetFinished(this) } - def hostLost(hostname: String) { - logInfo("Re-queueing tasks for " + hostname + " from TaskSet " + taskSet.id) - // If some task has preferred locations only on hostname, put it in the no-prefs list - // to avoid the wait from delay scheduling - for (index <- getPendingTasksForHost(hostname)) { - val newLocs = tasks(index).preferredLocations.toSet & sched.hostsAlive - if (newLocs.isEmpty) { - pendingTasksWithNoPrefs += index + def executorLost(execId: String, hostname: String) { + logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) + val newHostsAlive = sched.hostsAlive + // If some task has preferred locations only on hostname, and there are no more executors there, + // put it in the no-prefs list to avoid the wait from delay scheduling + if (!newHostsAlive.contains(hostname)) { + for (index <- getPendingTasksForHost(hostname)) { + val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive + if (newLocs.isEmpty) { + pendingTasksWithNoPrefs += index + } } } - // Re-enqueue any tasks that ran on the failed host if this is a shuffle map stage + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage if (tasks(0).isInstanceOf[ShuffleMapTask]) { - for ((tid, info) <- taskInfos if info.host == hostname) { + for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index if (finished(index)) { finished(index) = false @@ -382,7 +386,7 @@ private[spark] class TaskSetManager( } } // Also re-enqueue any tasks that were running on the node - for ((tid, info) <- taskInfos if info.running && info.host == hostname) { + for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { taskLost(tid, TaskState.KILLED, null) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala index 6b919d68b2..3c3afcbb14 100644 --- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala +++ b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala @@ -1,8 +1,8 @@ package spark.scheduler.cluster /** - * Represents free resources available on a worker node. + * Represents free resources available on an executor. */ private[spark] -class WorkerOffer(val slaveId: String, val hostname: String, val cores: Int) { +class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) { } diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index 2989e31f5e..f3467db86b 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -268,7 +268,7 @@ private[spark] class MesosSchedulerBackend( synchronized { slaveIdsWithExecutors -= slaveId.getValue } - scheduler.slaveLost(slaveId.getValue, reason) + scheduler.executorLost(slaveId.getValue, reason) } override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 19d35b8667..1215d5f5c8 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -30,6 +30,7 @@ extends Exception(message) private[spark] class BlockManager( + executorId: String, actorSystem: ActorSystem, val master: BlockManagerMaster, val serializer: Serializer, @@ -68,8 +69,8 @@ class BlockManager( val connectionManager = new ConnectionManager(0) implicit val futureExecContext = connectionManager.futureExecContext - val connectionManagerId = connectionManager.id - val blockManagerId = BlockManagerId(connectionManagerId.host, connectionManagerId.port) + val blockManagerId = BlockManagerId( + executorId, connectionManager.id.host, connectionManager.id.port) // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) @@ -109,8 +110,9 @@ class BlockManager( /** * Construct a BlockManager with a memory limit set based on system properties. */ - def this(actorSystem: ActorSystem, master: BlockManagerMaster, serializer: Serializer) = { - this(actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties) + def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, + serializer: Serializer) = { + this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties) } /** diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index abb8b45a1f..f2f1e77d41 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -7,27 +7,32 @@ import java.util.concurrent.ConcurrentHashMap * This class represent an unique identifier for a BlockManager. * The first 2 constructors of this class is made private to ensure that * BlockManagerId objects can be created only using the factory method in - * [[spark.storage.BlockManager$]]. This allows de-duplication of id objects. + * [[spark.storage.BlockManager$]]. This allows de-duplication of ID objects. * Also, constructor parameters are private to ensure that parameters cannot * be modified from outside this class. */ private[spark] class BlockManagerId private ( + private var executorId_ : String, private var ip_ : String, private var port_ : Int ) extends Externalizable { - private def this() = this(null, 0) // For deserialization only + private def this() = this(null, null, 0) // For deserialization only - def ip = ip_ + def executorId: String = executorId_ - def port = port_ + def ip: String = ip_ + + def port: Int = port_ override def writeExternal(out: ObjectOutput) { + out.writeUTF(executorId_) out.writeUTF(ip_) out.writeInt(port_) } override def readExternal(in: ObjectInput) { + executorId_ = in.readUTF() ip_ = in.readUTF() port_ = in.readInt() } @@ -35,21 +40,23 @@ private[spark] class BlockManagerId private ( @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = "BlockManagerId(" + ip + ", " + port + ")" + override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port) - override def hashCode = ip.hashCode * 41 + port + override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port override def equals(that: Any) = that match { - case id: BlockManagerId => port == id.port && ip == id.ip - case _ => false + case id: BlockManagerId => + executorId == id.executorId && port == id.port && ip == id.ip + case _ => + false } } private[spark] object BlockManagerId { - def apply(ip: String, port: Int) = - getCachedBlockManagerId(new BlockManagerId(ip, port)) + def apply(execId: String, ip: String, port: Int) = + getCachedBlockManagerId(new BlockManagerId(execId, ip, port)) def apply(in: ObjectInput) = { val obj = new BlockManagerId() diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 937115e92c..55ff1dde9c 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -24,7 +24,7 @@ private[spark] class BlockManagerMaster( masterPort: Int) extends Logging { - val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "3").toInt + val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager" @@ -45,10 +45,10 @@ private[spark] class BlockManagerMaster( } } - /** Remove a dead host from the master actor. This is only called on the master side. */ - def notifyADeadHost(host: String) { - tell(RemoveHost(host)) - logInfo("Removed " + host + " successfully in notifyADeadHost") + /** Remove a dead executor from the master actor. This is only called on the master side. */ + def removeExecutor(execId: String) { + tell(RemoveExecutor(execId)) + logInfo("Removed " + execId + " successfully in removeExecutor") } /** @@ -146,7 +146,7 @@ private[spark] class BlockManagerMaster( } var attempts = 0 var lastException: Exception = null - while (attempts < AKKA_RETRY_ATTEMPS) { + while (attempts < AKKA_RETRY_ATTEMPTS) { attempts += 1 try { val future = masterActor.ask(message)(timeout) diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index b31b6286d3..f88517f1a3 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -23,9 +23,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo] - // Mapping from host name to block manager id. We allow multiple block managers - // on the same host name (ip). - private val blockManagerIdByHost = new HashMap[String, ArrayBuffer[BlockManagerId]] + // Mapping from executor ID to block manager ID. + private val blockManagerIdByExecutor = new HashMap[String, BlockManagerId] // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] @@ -74,8 +73,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { case RemoveBlock(blockId) => removeBlock(blockId) - case RemoveHost(host) => - removeHost(host) + case RemoveExecutor(execId) => + removeExecutor(execId) sender ! true case StopBlockManagerMaster => @@ -99,16 +98,12 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { def removeBlockManager(blockManagerId: BlockManagerId) { val info = blockManagerInfo(blockManagerId) - // Remove the block manager from blockManagerIdByHost. If the list of block - // managers belonging to the IP is empty, remove the entry from the hash map. - blockManagerIdByHost.get(blockManagerId.ip).foreach { managers: ArrayBuffer[BlockManagerId] => - managers -= blockManagerId - if (managers.size == 0) blockManagerIdByHost.remove(blockManagerId.ip) - } + // Remove the block manager from blockManagerIdByExecutor. + blockManagerIdByExecutor -= blockManagerId.executorId // Remove it from blockManagerInfo and remove all the blocks. blockManagerInfo.remove(blockManagerId) - var iterator = info.blocks.keySet.iterator + val iterator = info.blocks.keySet.iterator while (iterator.hasNext) { val blockId = iterator.next val locations = blockLocations.get(blockId)._2 @@ -133,17 +128,15 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { toRemove.foreach(removeBlockManager) } - def removeHost(host: String) { - logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.") - logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq) - blockManagerIdByHost.get(host).foreach(_.foreach(removeBlockManager)) - logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq) + def removeExecutor(execId: String) { + logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") + blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) sender ! true } def heartBeat(blockManagerId: BlockManagerId) { if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + if (blockManagerId.executorId == "" && !isLocal) { sender ! true } else { sender ! false @@ -188,24 +181,20 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { sender ! res } - private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockManagerId + " " - - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { - logInfo("Got Register Msg from master node, don't register it") - } else if (!blockManagerInfo.contains(blockManagerId)) { - blockManagerIdByHost.get(blockManagerId.ip) match { - case Some(managers) => - // A block manager of the same host name already exists. - logInfo("Got another registration for host " + blockManagerId) - managers += blockManagerId + private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + if (id.executorId == "" && !isLocal) { + // Got a register message from the master node; don't register it + } else if (!blockManagerInfo.contains(id)) { + blockManagerIdByExecutor.get(id.executorId) match { + case Some(manager) => + // A block manager of the same host name already exists + logError("Got two different block manager registrations on " + id.executorId) + System.exit(1) case None => - blockManagerIdByHost += (blockManagerId.ip -> ArrayBuffer(blockManagerId)) + blockManagerIdByExecutor(id.executorId) = id } - - blockManagerInfo += (blockManagerId -> new BlockManagerMasterActor.BlockManagerInfo( - blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor)) + blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo( + id, System.currentTimeMillis(), maxMemSize, slaveActor) } sender ! true } @@ -217,11 +206,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { memSize: Long, diskSize: Long) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockManagerId + " " + blockId + " " - if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + if (blockManagerId.executorId == "" && !isLocal) { // We intentionally do not register the master (except in local mode), // so we should not indicate failure. sender ! true @@ -353,8 +339,8 @@ object BlockManagerMasterActor { _lastSeenMs = System.currentTimeMillis() } - def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long) - : Unit = synchronized { + def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, + diskSize: Long) { updateLastSeenMs() diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index 3d03ff3a93..1494f90103 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -88,7 +88,7 @@ private[spark] case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster private[spark] -case class RemoveHost(host: String) extends ToBlockManagerMaster +case class RemoveExecutor(execId: String) extends ToBlockManagerMaster private[spark] case object StopBlockManagerMaster extends ToBlockManagerMaster diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 1003cc7a61..b7423c7234 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -11,6 +11,7 @@ import cc.spray.typeconversion.TwirlSupport._ import scala.collection.mutable.ArrayBuffer import spark.{Logging, SparkContext, SparkEnv} import spark.util.AkkaUtils +import spark.Utils private[spark] @@ -20,10 +21,10 @@ object BlockManagerUI extends Logging { def start(actorSystem : ActorSystem, masterActor: ActorRef, sc: SparkContext) { val webUIDirectives = new BlockManagerUIDirectives(actorSystem, masterActor, sc) try { - logInfo("Starting BlockManager WebUI.") - val port = Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt - AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, + val boundPort = AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", + Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt, webUIDirectives.handler, "BlockManagerHTTPServer") + logInfo("Started BlockManager web UI at %s:%d".format(Utils.localHostName(), boundPort)) } catch { case e: Exception => logError("Failed to create BlockManager WebUI", e) diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala index 689f07b969..f04c046c31 100644 --- a/core/src/main/scala/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/spark/storage/ThreadingTest.scala @@ -78,7 +78,8 @@ private[spark] object ThreadingTest { val masterIp: String = System.getProperty("spark.master.host", "localhost") val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, masterIp, masterPort) - val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer, 1024 * 1024) + val blockManager = new BlockManager( + "", actorSystem, blockManagerMaster, serializer, 1024 * 1024) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index ff2c3079be..775ff8f1aa 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -52,10 +52,10 @@ private[spark] object AkkaUtils { /** * Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to - * handle requests. Throws a SparkException if this fails. + * handle requests. Returns the bound port or throws a SparkException on failure. */ def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, - name: String = "HttpServer") { + name: String = "HttpServer"): Int = { val ioWorker = new IoWorker(actorSystem).start() val httpService = actorSystem.actorOf(Props(new HttpService(route))) val rootService = actorSystem.actorOf(Props(new SprayCanRootService(httpService))) @@ -67,7 +67,7 @@ private[spark] object AkkaUtils { try { Await.result(future, timeout) match { case bound: HttpServer.Bound => - return + return bound.endpoint.getPort case other: Any => throw new SparkException("Failed to bind web UI to port " + port + ": " + other) } diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index bb7c5c01c8..188f8910da 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -63,9 +63,9 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() - override def size(): Int = internalMap.size() + override def size: Int = internalMap.size - override def foreach[U](f: ((A, B)) => U): Unit = { + override def foreach[U](f: ((A, B)) => U) { val iterator = internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala index 70a7c8bc2f..342610e1dd 100644 --- a/core/src/test/scala/spark/DriverSuite.scala +++ b/core/src/test/scala/spark/DriverSuite.scala @@ -13,7 +13,8 @@ class DriverSuite extends FunSuite with Timeouts { val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) forAll(masters) { (master: String) => failAfter(10 seconds) { - Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master), new File(System.getenv("SPARK_HOME"))) + Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master), + new File(System.getenv("SPARK_HOME"))) } } } @@ -28,4 +29,4 @@ object DriverWithoutCleanup { val sc = new SparkContext(args(0), "DriverWithoutCleanup") sc.parallelize(1 to 100, 4).count() } -} \ No newline at end of file +} diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 7d5305f1e0..e8fe7ecabc 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -43,13 +43,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000))) val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === Seq((BlockManagerId("hostA", 1000), size1000), - (BlockManagerId("hostB", 1000), size10000))) + assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), + (BlockManagerId("b", "hostB", 1000), size10000))) tracker.stop() } @@ -61,47 +61,52 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), Array(compressedSize1000, compressedSize1000, compressedSize1000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000, compressedSize1000))) // As if we had two simulatenous fetch failures - tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) - tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) - // The remaining reduce task might try to grab the output dispite the shuffle failure; + // The remaining reduce task might try to grab the output despite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } } test("remote fetch") { - val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("test", "localhost", 0) - System.setProperty("spark.master.port", boundPort.toString) - val masterTracker = new MapOutputTracker(actorSystem, true) - val slaveTracker = new MapOutputTracker(actorSystem, false) - masterTracker.registerShuffle(10, 1) - masterTracker.incrementGeneration() - slaveTracker.updateGeneration(masterTracker.getGeneration) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + try { + System.clearProperty("spark.master.host") // In case some previous test had set it + val (actorSystem, boundPort) = + AkkaUtils.createActorSystem("test", "localhost", 0) + System.setProperty("spark.master.port", boundPort.toString) + val masterTracker = new MapOutputTracker(actorSystem, true) + val slaveTracker = new MapOutputTracker(actorSystem, false) + masterTracker.registerShuffle(10, 1) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("hostA", 1000), Array(compressedSize1000))) - masterTracker.incrementGeneration() - slaveTracker.updateGeneration(masterTracker.getGeneration) - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("hostA", 1000), size1000))) + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("a", "hostA", 1000), Array(compressedSize1000))) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000), size1000))) - masterTracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) - masterTracker.incrementGeneration() - slaveTracker.updateGeneration(masterTracker.getGeneration) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - // failure should be cached - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + // failure should be cached + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + } finally { + System.clearProperty("spark.master.port") + } } } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 2165744689..2d177bbf67 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -86,9 +86,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("BlockManagerId object caching") { - val id1 = BlockManagerId("XXX", 1) - val id2 = BlockManagerId("XXX", 1) // this should return the same object as id1 - val id3 = BlockManagerId("XXX", 2) // this should return a different object + val id1 = BlockManagerId("e1", "XXX", 1) + val id2 = BlockManagerId("e1", "XXX", 1) // this should return the same object as id1 + val id3 = BlockManagerId("e1", "XXX", 2) // this should return a different object assert(id2 === id1, "id2 is not same as id1") assert(id2.eq(id1), "id2 is not the same object as id1") assert(id3 != id1, "id3 is same as id1") @@ -103,7 +103,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 1 manager interaction") { - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("", actorSystem, master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -133,8 +133,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 2 managers interaction") { - store = new BlockManager(actorSystem, master, serializer, 2000) - store2 = new BlockManager(actorSystem, master, new KryoSerializer, 2000) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000) + store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer, 2000) val peers = master.getPeers(store.blockManagerId, 1) assert(peers.size === 1, "master did not return the other manager as a peer") @@ -149,7 +149,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing block") { - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("", actorSystem, master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -198,7 +198,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("", actorSystem, master, serializer, 2000) val a1 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) @@ -206,7 +206,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.getSingle("a1") != None, "a1 was not in store") assert(master.getLocations("a1").size > 0, "master was not told about a1") - master.notifyADeadHost(store.blockManagerId.ip) + master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") store invokePrivate heartBeat() @@ -214,14 +214,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("reregistration on block update") { - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("", actorSystem, master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) assert(master.getLocations("a1").size > 0, "master was not told about a1") - master.notifyADeadHost(store.blockManagerId.ip) + master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY) @@ -233,35 +233,35 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration doesn't dead lock") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("", actorSystem, master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) // try many times to trigger any deadlocks for (i <- 1 to 100) { - master.notifyADeadHost(store.blockManagerId.ip) + master.removeExecutor(store.blockManagerId.executorId) val t1 = new Thread { - override def run = { + override def run() { store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, true) } } val t2 = new Thread { - override def run = { + override def run() { store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) } } val t3 = new Thread { - override def run = { + override def run() { store invokePrivate heartBeat() } } - t1.start - t2.start - t3.start - t1.join - t2.join - t3.join + t1.start() + t2.start() + t3.start() + t1.join() + t2.join() + t3.join() store.dropFromMemory("a1", null) store.dropFromMemory("a2", null) @@ -270,7 +270,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -289,7 +289,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage with serialization") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -308,14 +308,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of same RDD") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY) store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY) store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY) - // Even though we accessed rdd_0_3 last, it should not have replaced partitiosn 1 and 2 + // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2 // from the same RDD assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") @@ -327,7 +327,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of multiple RDDs") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) @@ -350,7 +350,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("on-disk storage") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -363,7 +363,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -378,7 +378,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with getLocalBytes") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -393,7 +393,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -408,7 +408,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization and getLocalBytes") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -423,7 +423,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -448,7 +448,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU with streams") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -472,7 +472,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels and streams") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -518,7 +518,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("overly large block") { - store = new BlockManager(actorSystem, master, serializer, 500) + store = new BlockManager("", actorSystem, master, serializer, 500) store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.getSingle("a1") === None, "a1 was in store") store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) @@ -529,49 +529,49 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block compression") { try { System.setProperty("spark.shuffle.compress", "true") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000) store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed") store.stop() store = null System.setProperty("spark.shuffle.compress", "false") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec2", actorSystem, master, serializer, 2000) store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed") store.stop() store = null System.setProperty("spark.broadcast.compress", "true") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec3", actorSystem, master, serializer, 2000) store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed") store.stop() store = null System.setProperty("spark.broadcast.compress", "false") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec4", actorSystem, master, serializer, 2000) store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed") store.stop() store = null System.setProperty("spark.rdd.compress", "true") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec5", actorSystem, master, serializer, 2000) store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed") store.stop() store = null System.setProperty("spark.rdd.compress", "false") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec6", actorSystem, master, serializer, 2000) store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed") store.stop() store = null // Check that any other block types are also kept uncompressed - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec7", actorSystem, master, serializer, 2000) store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") store.stop() diff --git a/sbt/sbt b/sbt/sbt index a3055c13c1..8f426d18e8 100755 --- a/sbt/sbt +++ b/sbt/sbt @@ -5,4 +5,4 @@ if [ "$MESOS_HOME" != "" ]; then fi export SPARK_HOME=$(cd "$(dirname $0)/.."; pwd) export SPARK_TESTING=1 # To put test classes on classpath -java -Xmx1200M -XX:MaxPermSize=200m $EXTRA_ARGS -jar $SPARK_HOME/sbt/sbt-launch-*.jar "$@" +java -Xmx1200M -XX:MaxPermSize=250m $EXTRA_ARGS -jar $SPARK_HOME/sbt/sbt-launch-*.jar "$@" -- cgit v1.2.3 From 909850729ec59b788645575fdc03df7cc51fe42b Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 27 Jan 2013 23:17:20 -0800 Subject: Rename more things from slave to executor --- .../scala/spark/deploy/worker/ExecutorRunner.scala | 2 +- .../spark/executor/StandaloneExecutorBackend.scala | 12 +++--- .../spark/scheduler/cluster/SlaveResources.scala | 4 -- .../cluster/SparkDeploySchedulerBackend.scala | 16 ++------ .../cluster/StandaloneClusterMessage.scala | 16 ++++---- .../cluster/StandaloneSchedulerBackend.scala | 48 +++++++++++----------- .../main/scala/spark/storage/BlockManagerUI.scala | 2 + .../main/scala/spark/util/MetadataCleaner.scala | 10 ++--- 8 files changed, 50 insertions(+), 60 deletions(-) delete mode 100644 core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index af3acfecb6..f5ff267d44 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -65,7 +65,7 @@ private[spark] class ExecutorRunner( } } - /** Replace variables such as {{SLAVEID}} and {{CORES}} in a command argument passed to us */ + /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ def substituteVariables(argument: String): String = argument match { case "{{EXECUTOR_ID}}" => execId.toString case "{{HOSTNAME}}" => hostname diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 435ee5743e..50871802ea 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -8,10 +8,10 @@ import akka.actor.{ActorRef, Actor, Props} import java.util.concurrent.{TimeUnit, ThreadPoolExecutor, SynchronousQueue} import akka.remote.RemoteClientLifeCycleEvent import spark.scheduler.cluster._ -import spark.scheduler.cluster.RegisteredSlave +import spark.scheduler.cluster.RegisteredExecutor import spark.scheduler.cluster.LaunchTask -import spark.scheduler.cluster.RegisterSlaveFailed -import spark.scheduler.cluster.RegisterSlave +import spark.scheduler.cluster.RegisterExecutorFailed +import spark.scheduler.cluster.RegisterExecutor private[spark] class StandaloneExecutorBackend( @@ -30,7 +30,7 @@ private[spark] class StandaloneExecutorBackend( try { logInfo("Connecting to master: " + masterUrl) master = context.actorFor(masterUrl) - master ! RegisterSlave(executorId, hostname, cores) + master ! RegisterExecutor(executorId, hostname, cores) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(master) // Doesn't work with remote actors, but useful for testing } catch { @@ -41,11 +41,11 @@ private[spark] class StandaloneExecutorBackend( } override def receive = { - case RegisteredSlave(sparkProperties) => + case RegisteredExecutor(sparkProperties) => logInfo("Successfully registered with master") executor.initialize(executorId, hostname, sparkProperties) - case RegisterSlaveFailed(message) => + case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) System.exit(1) diff --git a/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala b/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala deleted file mode 100644 index 96ebaa4601..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala +++ /dev/null @@ -1,4 +0,0 @@ -package spark.scheduler.cluster - -private[spark] -class SlaveResources(val slaveId: String, val hostname: String, val coresFree: Int) {} diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index f0792c1b76..6dd3ae003d 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,7 +19,6 @@ private[spark] class SparkDeploySchedulerBackend( var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt - val executorIdToSlaveId = new HashMap[String, String] // Memory used by each executor (in megabytes) val executorMemory = { @@ -47,7 +46,7 @@ private[spark] class SparkDeploySchedulerBackend( } override def stop() { - stopping = true; + stopping = true super.stop() client.stop() if (shutdownCallback != null) { @@ -67,23 +66,16 @@ private[spark] class SparkDeploySchedulerBackend( } def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) { - executorIdToSlaveId += id -> workerId logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format( id, host, cores, Utils.memoryMegabytesToString(memory))) } - def executorRemoved(id: String, message: String, exitStatus: Option[Int]) { + def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) { val reason: ExecutorLossReason = exitStatus match { case Some(code) => ExecutorExited(code) case None => SlaveLost(message) } - logInfo("Executor %s removed: %s".format(id, message)) - executorIdToSlaveId.get(id) match { - case Some(slaveId) => - executorIdToSlaveId.remove(id) - scheduler.executorLost(slaveId, reason) - case None => - logInfo("No slave ID known for executor %s".format(id)) - } + logInfo("Executor %s removed: %s".format(executorId, message)) + scheduler.executorLost(executorId, reason) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala index 1386cd9d44..c68f15bdfa 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -11,24 +11,26 @@ private[spark] case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage private[spark] -case class RegisteredSlave(sparkProperties: Seq[(String, String)]) extends StandaloneClusterMessage +case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) + extends StandaloneClusterMessage private[spark] -case class RegisterSlaveFailed(message: String) extends StandaloneClusterMessage +case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage -// Slaves to master +// Executors to master private[spark] -case class RegisterSlave(slaveId: String, host: String, cores: Int) extends StandaloneClusterMessage +case class RegisterExecutor(executorId: String, host: String, cores: Int) + extends StandaloneClusterMessage private[spark] -case class StatusUpdate(slaveId: String, taskId: Long, state: TaskState, data: SerializableBuffer) +case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer) extends StandaloneClusterMessage private[spark] object StatusUpdate { /** Alternate factory method that takes a ByteBuffer directly for the data field */ - def apply(slaveId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = { - StatusUpdate(slaveId, taskId, state, new SerializableBuffer(data)) + def apply(executorId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = { + StatusUpdate(executorId, taskId, state, new SerializableBuffer(data)) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 32be1e7a26..69822f568c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -24,9 +24,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor var totalCoreCount = new AtomicInteger(0) class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor { - val slaveActor = new HashMap[String, ActorRef] - val slaveAddress = new HashMap[String, Address] - val slaveHost = new HashMap[String, String] + val executorActor = new HashMap[String, ActorRef] + val executorAddress = new HashMap[String, Address] + val executorHost = new HashMap[String, String] val freeCores = new HashMap[String, Int] val actorToExecutorId = new HashMap[ActorRef, String] val addressToExecutorId = new HashMap[Address, String] @@ -37,17 +37,17 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } def receive = { - case RegisterSlave(executorId, host, cores) => - if (slaveActor.contains(executorId)) { - sender ! RegisterSlaveFailed("Duplicate executor ID: " + executorId) + case RegisterExecutor(executorId, host, cores) => + if (executorActor.contains(executorId)) { + sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) } else { logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! RegisteredSlave(sparkProperties) + sender ! RegisteredExecutor(sparkProperties) context.watch(sender) - slaveActor(executorId) = sender - slaveHost(executorId) = host + executorActor(executorId) = sender + executorHost(executorId) = host freeCores(executorId) = cores - slaveAddress(executorId) = sender.path.address + executorAddress(executorId) = sender.path.address actorToExecutorId(sender) = executorId addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) @@ -69,45 +69,45 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor context.stop(self) case Terminated(actor) => - actorToExecutorId.get(actor).foreach(removeSlave(_, "Akka actor terminated")) + actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated")) case RemoteClientDisconnected(transport, address) => - addressToExecutorId.get(address).foreach(removeSlave(_, "remote Akka client disconnected")) + addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected")) case RemoteClientShutdown(transport, address) => - addressToExecutorId.get(address).foreach(removeSlave(_, "remote Akka client shutdown")) + addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown")) } - // Make fake resource offers on all slaves + // Make fake resource offers on all executors def makeOffers() { launchTasks(scheduler.resourceOffers( - slaveHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) + executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) } - // Make fake resource offers on just one slave + // Make fake resource offers on just one executor def makeOffers(executorId: String) { launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, slaveHost(executorId), freeCores(executorId))))) + Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId))))) } // Launch tasks returned by a set of resource offers def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { freeCores(task.executorId) -= 1 - slaveActor(task.executorId) ! LaunchTask(task) + executorActor(task.executorId) ! LaunchTask(task) } } // Remove a disconnected slave from the cluster - def removeSlave(executorId: String, reason: String) { + def removeExecutor(executorId: String, reason: String) { logInfo("Slave " + executorId + " disconnected, so removing it") val numCores = freeCores(executorId) - actorToExecutorId -= slaveActor(executorId) - addressToExecutorId -= slaveAddress(executorId) - slaveActor -= executorId - slaveHost -= executorId + actorToExecutorId -= executorActor(executorId) + addressToExecutorId -= executorAddress(executorId) + executorActor -= executorId + executorHost -= executorId freeCores -= executorId - slaveHost -= executorId + executorHost -= executorId totalCoreCount.addAndGet(-numCores) scheduler.executorLost(executorId, SlaveLost(reason)) } diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index b7423c7234..956ede201e 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -21,6 +21,8 @@ object BlockManagerUI extends Logging { def start(actorSystem : ActorSystem, masterActor: ActorRef, sc: SparkContext) { val webUIDirectives = new BlockManagerUIDirectives(actorSystem, masterActor, sc) try { + // TODO: This needs to find a random free port to bind to. Unfortunately, there's no way + // in spray to do that, so we'll have to rely on something like new ServerSocket() val boundPort = AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt, webUIDirectives.handler, "BlockManagerHTTPServer") diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index 139e21d09e..721c4c6029 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -14,18 +14,16 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging val task = new TimerTask { def run() { try { - if (delaySeconds > 0) { - cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) - logInfo("Ran metadata cleaner for " + name) - } + cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) + logInfo("Ran metadata cleaner for " + name) } catch { case e: Exception => logError("Error running cleanup task for " + name, e) } } } - if (periodSeconds > 0) { - logInfo( + if (delaySeconds > 0) { + logDebug( "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and " + "period of " + periodSeconds + " secs") timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) -- cgit v1.2.3 From f03d9760fd8ac67fd0865cb355ba75d2eff507fe Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 27 Jan 2013 23:56:14 -0800 Subject: Clean up BlockManagerUI a little (make it not be an object, merge with Directives, and bind to a random port) --- core/src/main/scala/spark/SparkContext.scala | 7 +- core/src/main/scala/spark/Utils.scala | 17 ++- .../scala/spark/deploy/master/MasterWebUI.scala | 6 +- .../scala/spark/deploy/worker/WorkerWebUI.scala | 6 +- .../main/scala/spark/storage/BlockManagerUI.scala | 120 ++++++++++----------- core/src/main/scala/spark/util/AkkaUtils.scala | 6 +- .../main/scala/spark/util/MetadataCleaner.scala | 3 + 7 files changed, 91 insertions(+), 74 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 39721b47ae..77036c1275 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -44,6 +44,7 @@ import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} +import storage.BlockManagerUI import util.{MetadataCleaner, TimeStampedHashMap} /** @@ -88,8 +89,9 @@ class SparkContext( SparkEnv.set(env) // Start the BlockManager UI - spark.storage.BlockManagerUI.start(SparkEnv.get.actorSystem, - SparkEnv.get.blockManager.master.masterActor, this) + private[spark] val ui = new BlockManagerUI( + env.actorSystem, env.blockManager.master.masterActor, this) + ui.start() // Used to store a URL for each static file/jar together with the file's local timestamp private[spark] val addedFiles = HashMap[String, Long]() @@ -97,7 +99,6 @@ class SparkContext( // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]() - private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index ae77264372..1e58d01273 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -1,7 +1,7 @@ package spark import java.io._ -import java.net.{NetworkInterface, InetAddress, Inet4Address, URL, URI} +import java.net._ import java.util.{Locale, Random, UUID} import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} import org.apache.hadoop.conf.Configuration @@ -11,6 +11,7 @@ import scala.collection.JavaConversions._ import scala.io.Source import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder +import scala.Some /** * Various utility methods used by Spark. @@ -431,4 +432,18 @@ private object Utils extends Logging { } "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine) } + + /** + * Try to find a free port to bind to on the local host. This should ideally never be needed, + * except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray) + * don't let users bind to port 0 and then figure out which free port they actually bound to. + * We work around this by binding a ServerSocket and immediately unbinding it. This is *not* + * necessarily guaranteed to work, but it's the best we can do. + */ + def findFreePort(): Int = { + val socket = new ServerSocket(0) + val portBound = socket.getLocalPort + socket.close() + portBound + } } diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index 458ee2d665..a01774f511 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -14,12 +14,15 @@ import cc.spray.typeconversion.SprayJsonSupport._ import spark.deploy._ import spark.deploy.JsonProtocol._ +/** + * Web UI server for the standalone master. + */ private[spark] class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives { val RESOURCE_DIR = "spark/deploy/master/webui" val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(1 seconds) + implicit val timeout = Timeout(10 seconds) val handler = { get { @@ -76,5 +79,4 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct getFromResourceDirectory(RESOURCE_DIR) } } - } diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index f9489d99fc..ef81f072a3 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -13,12 +13,15 @@ import cc.spray.typeconversion.SprayJsonSupport._ import spark.deploy.{WorkerState, RequestWorkerState} import spark.deploy.JsonProtocol._ +/** + * Web UI server for the standalone worker. + */ private[spark] class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives { val RESOURCE_DIR = "spark/deploy/worker/webui" val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(1 seconds) + implicit val timeout = Timeout(10 seconds) val handler = { get { @@ -50,5 +53,4 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct getFromResourceDirectory(RESOURCE_DIR) } } - } diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 956ede201e..eda320fa47 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -1,32 +1,41 @@ package spark.storage import akka.actor.{ActorRef, ActorSystem} -import akka.dispatch.Await import akka.pattern.ask import akka.util.Timeout import akka.util.duration._ -import cc.spray.Directives import cc.spray.directives._ import cc.spray.typeconversion.TwirlSupport._ +import cc.spray.Directives import scala.collection.mutable.ArrayBuffer -import spark.{Logging, SparkContext, SparkEnv} +import spark.{Logging, SparkContext} import spark.util.AkkaUtils import spark.Utils +/** + * Web UI server for the BlockManager inside each SparkContext. + */ private[spark] -object BlockManagerUI extends Logging { +class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, sc: SparkContext) + extends Directives with Logging { + + val STATIC_RESOURCE_DIR = "spark/deploy/static" + + implicit val timeout = Timeout(10 seconds) - /* Starts the Web interface for the BlockManager */ - def start(actorSystem : ActorSystem, masterActor: ActorRef, sc: SparkContext) { - val webUIDirectives = new BlockManagerUIDirectives(actorSystem, masterActor, sc) + /** Start a HTTP server to run the Web interface */ + def start() { try { - // TODO: This needs to find a random free port to bind to. Unfortunately, there's no way - // in spray to do that, so we'll have to rely on something like new ServerSocket() - val boundPort = AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", - Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt, - webUIDirectives.handler, "BlockManagerHTTPServer") - logInfo("Started BlockManager web UI at %s:%d".format(Utils.localHostName(), boundPort)) + val port = if (System.getProperty("spark.ui.port") != null) { + System.getProperty("spark.ui.port").toInt + } else { + // TODO: Unfortunately, it's not possible to pass port 0 to spray and figure out which + // random port it bound to, so we have to try to find a local one by creating a socket. + Utils.findFreePort() + } + AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, handler, "BlockManagerHTTPServer") + logInfo("Started BlockManager web UI at http://%s:%d".format(Utils.localHostName(), port)) } catch { case e: Exception => logError("Failed to create BlockManager WebUI", e) @@ -34,58 +43,43 @@ object BlockManagerUI extends Logging { } } -} - - -private[spark] -class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, - sc: SparkContext) extends Directives { - - val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(1 seconds) - val handler = { - - get { path("") { completeWith { - // Request the current storage status from the Master - val future = master ? GetStorageStatus - future.map { status => - val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray - - // Calculate macro-level statistics - val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) - val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) - val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)) - .reduceOption(_+_).getOrElse(0L) - - val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) - - spark.storage.html.index. - render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) - } - }}} ~ - get { path("rdd") { parameter("id") { id => { completeWith { - val future = master ? GetStorageStatus - future.map { status => - val prefix = "rdd_" + id.toString - - - val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray - val filteredStorageStatusList = StorageUtils. - filterStorageStatusByPrefix(storageStatusList, prefix) - - val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head - - spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) - + get { + path("") { + completeWith { + // Request the current storage status from the Master + val future = blockManagerMaster ? GetStorageStatus + future.map { status => + // Calculate macro-level statistics + val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray + val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) + val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) + val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)) + .reduceOption(_+_).getOrElse(0L) + val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) + spark.storage.html.index. + render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) + } + } + } ~ + path("rdd") { + parameter("id") { id => + completeWith { + val future = blockManagerMaster ? GetStorageStatus + future.map { status => + val prefix = "rdd_" + id.toString + val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray + val filteredStorageStatusList = StorageUtils. + filterStorageStatusByPrefix(storageStatusList, prefix) + val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head + spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) + } + } + } + } ~ + pathPrefix("static") { + getFromResourceDirectory(STATIC_RESOURCE_DIR) } - }}}}} ~ - pathPrefix("static") { - getFromResourceDirectory(STATIC_RESOURCE_DIR) } - } - - - } diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index 775ff8f1aa..e0fdeffbc4 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -1,6 +1,6 @@ package spark.util -import akka.actor.{Props, ActorSystemImpl, ActorSystem} +import akka.actor.{ActorRef, Props, ActorSystemImpl, ActorSystem} import com.typesafe.config.ConfigFactory import akka.util.duration._ import akka.pattern.ask @@ -55,7 +55,7 @@ private[spark] object AkkaUtils { * handle requests. Returns the bound port or throws a SparkException on failure. */ def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, - name: String = "HttpServer"): Int = { + name: String = "HttpServer"): ActorRef = { val ioWorker = new IoWorker(actorSystem).start() val httpService = actorSystem.actorOf(Props(new HttpService(route))) val rootService = actorSystem.actorOf(Props(new SprayCanRootService(httpService))) @@ -67,7 +67,7 @@ private[spark] object AkkaUtils { try { Await.result(future, timeout) match { case bound: HttpServer.Bound => - return bound.endpoint.getPort + return server case other: Any => throw new SparkException("Failed to bind web UI to port " + port + ": " + other) } diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index 721c4c6029..51fb440108 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -5,6 +5,9 @@ import java.util.{TimerTask, Timer} import spark.Logging +/** + * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) + */ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { val delaySeconds = MetadataCleaner.getDelaySeconds -- cgit v1.2.3 From 286f8f876ff495df33a7966e77ca90d69f338450 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 28 Jan 2013 01:29:27 -0800 Subject: Change time unit in MetadataCleaner to seconds --- core/src/main/scala/spark/util/MetadataCleaner.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index 51fb440108..6cf93a9b17 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -9,7 +9,6 @@ import spark.Logging * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) */ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { - val delaySeconds = MetadataCleaner.getDelaySeconds val periodSeconds = math.max(10, delaySeconds / 10) val timer = new Timer(name + " cleanup timer", true) @@ -39,7 +38,7 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging object MetadataCleaner { - def getDelaySeconds = (System.getProperty("spark.cleaner.delay", "-100").toDouble * 60).toInt - def setDelaySeconds(delay: Long) { System.setProperty("spark.cleaner.delay", delay.toString) } + def getDelaySeconds = System.getProperty("spark.cleaner.delay", "-1").toInt + def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.delay", delay.toString) } } -- cgit v1.2.3 From 07f568e1bfc67eead88e2c5dbfb9cac23e1ac8bc Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 24 Jan 2013 15:27:29 -0800 Subject: SPARK-658: Adding logging of stage duration --- .../main/scala/spark/scheduler/DAGScheduler.scala | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index bd541d4207..8aad667182 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -86,6 +86,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] + val stageSubmissionTimes = new HashMap[Stage, Long] val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) @@ -393,6 +394,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with logDebug("New pending tasks: " + myPending) taskSched.submitTasks( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority)) + if (!stageSubmissionTimes.contains(stage)) { + stageSubmissionTimes.put(stage, System.currentTimeMillis()) + } } else { logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) @@ -407,6 +411,15 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def handleTaskCompletion(event: CompletionEvent) { val task = event.task val stage = idToStage(task.stageId) + + def stageFinished(stage: Stage) = { + val serviceTime = stageSubmissionTimes.remove(stage) match { + case Some(t) => (System.currentTimeMillis() - t).toString + case _ => "Unkown" + } + logInfo("%s (%s) finished in %s ms".format(stage, stage.origin, serviceTime)) + running -= stage + } event.reason match { case Success => logInfo("Completed " + task) @@ -421,13 +434,13 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with if (!job.finished(rt.outputId)) { job.finished(rt.outputId) = true job.numFinished += 1 - job.listener.taskSucceeded(rt.outputId, event.result) // If the whole job has finished, remove it if (job.numFinished == job.numPartitions) { activeJobs -= job resultStageToJob -= stage - running -= stage + stageFinished(stage) } + job.listener.taskSucceeded(rt.outputId, event.result) } case None => logInfo("Ignoring result from " + rt + " because its job has finished") @@ -444,8 +457,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with stage.addOutputLoc(smt.partition, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { - logInfo(stage + " (" + stage.origin + ") finished; looking for newly runnable stages") - running -= stage + stageFinished(stage) + logInfo("looking for newly runnable stages") logInfo("running: " + running) logInfo("waiting: " + waiting) logInfo("failed: " + failed) -- cgit v1.2.3 From c423be7d8e1349fc00431328b76b52f4eee8a975 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 24 Jan 2013 18:25:57 -0800 Subject: Renaming stage finished function --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 8aad667182..bce7418e87 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -412,7 +412,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val task = event.task val stage = idToStage(task.stageId) - def stageFinished(stage: Stage) = { + def markStageAsFinished(stage: Stage) = { val serviceTime = stageSubmissionTimes.remove(stage) match { case Some(t) => (System.currentTimeMillis() - t).toString case _ => "Unkown" @@ -438,7 +438,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with if (job.numFinished == job.numPartitions) { activeJobs -= job resultStageToJob -= stage - stageFinished(stage) + markStageAsFinished(stage) } job.listener.taskSucceeded(rt.outputId, event.result) } @@ -457,7 +457,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with stage.addOutputLoc(smt.partition, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { - stageFinished(stage) + markStageAsFinished(stage) logInfo("looking for newly runnable stages") logInfo("running: " + running) logInfo("waiting: " + waiting) -- cgit v1.2.3 From 501433f1d59b1b326c0a7169fa1fd6136f7628e3 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 28 Jan 2013 10:17:35 -0800 Subject: Making submission time a field --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 7 +++---- core/src/main/scala/spark/scheduler/Stage.scala | 3 +++ 2 files changed, 6 insertions(+), 4 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index bce7418e87..7ba1f3430a 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -86,7 +86,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] - val stageSubmissionTimes = new HashMap[Stage, Long] val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) @@ -394,8 +393,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with logDebug("New pending tasks: " + myPending) taskSched.submitTasks( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority)) - if (!stageSubmissionTimes.contains(stage)) { - stageSubmissionTimes.put(stage, System.currentTimeMillis()) + if (!stage.submissionTime.isDefined) { + stage.submissionTime = Some(System.currentTimeMillis()) } } else { logDebug("Stage " + stage + " is actually done; %b %d %d".format( @@ -413,7 +412,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val stage = idToStage(task.stageId) def markStageAsFinished(stage: Stage) = { - val serviceTime = stageSubmissionTimes.remove(stage) match { + val serviceTime = stage.submissionTime match { case Some(t) => (System.currentTimeMillis() - t).toString case _ => "Unkown" } diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala index e9419728e3..374114d870 100644 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -32,6 +32,9 @@ private[spark] class Stage( val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) var numAvailableOutputs = 0 + /** When first task was submitted to scheduler. */ + var submissionTime: Option[Long] = None + private var nextAttemptId = 0 def isAvailable: Boolean = { -- cgit v1.2.3 From a423ee546c389b5ce0d2117299456712370d7ad1 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 22 Jan 2013 18:48:43 -0800 Subject: expose RDD & storage info directly via SparkContext --- core/src/main/scala/spark/SparkContext.scala | 16 +++++++++ .../scala/spark/storage/BlockManagerMaster.scala | 4 +++ .../main/scala/spark/storage/BlockManagerUI.scala | 39 +++++++++------------- .../main/scala/spark/storage/StorageUtils.scala | 10 +++--- 4 files changed, 41 insertions(+), 28 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 77036c1275..be992250a9 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -46,6 +46,7 @@ import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, C import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import storage.BlockManagerUI import util.{MetadataCleaner, TimeStampedHashMap} +import storage.{StorageStatus, StorageUtils, RDDInfo} /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -473,6 +474,21 @@ class SparkContext( } } + /** + * Return information about what RDDs are cached, if they are in mem or on disk, how much space + * they take, etc. + */ + def getRDDStorageInfo : Array[RDDInfo] = { + StorageUtils.rddInfoFromStorageStatus(getSlavesStorageStatus, this) + } + + /** + * Return information about blocks stored in all of the slaves + */ + def getSlavesStorageStatus : Array[StorageStatus] = { + env.blockManager.master.getStorageStatus + } + /** * Clear the job's list of files added by `addFile` so that they do not get downloaded to * any new nodes. diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 55ff1dde9c..c7ee76f0b7 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -118,6 +118,10 @@ private[spark] class BlockManagerMaster( askMasterWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } + def getStorageStatus: Array[StorageStatus] = { + askMasterWithRetry[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray + } + /** Stop the master actor, called only on the Spark master node */ def stop() { if (masterActor != null) { diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index eda320fa47..52f6d1b657 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -1,13 +1,10 @@ package spark.storage import akka.actor.{ActorRef, ActorSystem} -import akka.pattern.ask import akka.util.Timeout import akka.util.duration._ -import cc.spray.directives._ import cc.spray.typeconversion.TwirlSupport._ import cc.spray.Directives -import scala.collection.mutable.ArrayBuffer import spark.{Logging, SparkContext} import spark.util.AkkaUtils import spark.Utils @@ -48,32 +45,26 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, path("") { completeWith { // Request the current storage status from the Master - val future = blockManagerMaster ? GetStorageStatus - future.map { status => - // Calculate macro-level statistics - val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray - val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) - val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) - val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)) - .reduceOption(_+_).getOrElse(0L) - val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) - spark.storage.html.index. - render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) - } + val storageStatusList = sc.getSlavesStorageStatus + // Calculate macro-level statistics + val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) + val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) + val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)) + .reduceOption(_+_).getOrElse(0L) + val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) + spark.storage.html.index. + render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) } } ~ path("rdd") { parameter("id") { id => completeWith { - val future = blockManagerMaster ? GetStorageStatus - future.map { status => - val prefix = "rdd_" + id.toString - val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray - val filteredStorageStatusList = StorageUtils. - filterStorageStatusByPrefix(storageStatusList, prefix) - val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head - spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) - } + val prefix = "rdd_" + id.toString + val storageStatusList = sc.getSlavesStorageStatus + val filteredStorageStatusList = StorageUtils. + filterStorageStatusByPrefix(storageStatusList, prefix) + val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head + spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) } } } ~ diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index a10e3a95c6..d6e33c8619 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -56,9 +56,11 @@ object StorageUtils { // Find the id of the RDD, e.g. rdd_1 => 1 val rddId = rddKey.split("_").last.toInt // Get the friendly name for the rdd, if available. - val rddName = Option(sc.persistentRdds(rddId).name).getOrElse(rddKey) - val rddStorageLevel = sc.persistentRdds(rddId).getStorageLevel - + val rdd = sc.persistentRdds(rddId) + val rddName = Option(rdd.name).getOrElse(rddKey) + val rddStorageLevel = rdd.getStorageLevel + //TODO get total number of partitions in rdd + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize) }.toArray } @@ -75,4 +77,4 @@ object StorageUtils { } -} \ No newline at end of file +} -- cgit v1.2.3 From 0f22c4207f27bc8d1675af82f873141dda754f5c Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 28 Jan 2013 10:08:59 -0800 Subject: better formatting for RDDInfo --- core/src/main/scala/spark/storage/StorageUtils.scala | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index d6e33c8619..ce7c067eea 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -1,6 +1,6 @@ package spark.storage -import spark.SparkContext +import spark.{Utils, SparkContext} import BlockManagerMasterActor.BlockStatus private[spark] @@ -22,8 +22,14 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, } case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, - numPartitions: Int, memSize: Long, diskSize: Long) - + numPartitions: Int, memSize: Long, diskSize: Long) { + override def toString = { + import Utils.memoryBytesToString + import java.lang.{Integer => JInt} + String.format("RDD \"%s\" (%d) Storage: %s; Partitions: %d; MemorySize: %s; DiskSize: %s", name, id.asInstanceOf[JInt], + storageLevel.toString, numPartitions.asInstanceOf[JInt], memoryBytesToString(memSize), memoryBytesToString(diskSize)) + } +} /* Helper methods for storage-related objects */ private[spark] -- cgit v1.2.3 From efff7bfb3382f4e07f9fad0e6e647c0ec629355e Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 28 Jan 2013 20:23:11 -0800 Subject: add long and float accumulatorparams --- core/src/main/scala/spark/SparkContext.scala | 10 ++++++++++ core/src/test/scala/spark/AccumulatorSuite.scala | 6 ++++++ 2 files changed, 16 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 77036c1275..dc9b8688b3 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -673,6 +673,16 @@ object SparkContext { def zero(initialValue: Int) = 0 } + implicit object LongAccumulatorParam extends AccumulatorParam[Long] { + def addInPlace(t1: Long, t2: Long) = t1 + t2 + def zero(initialValue: Long) = 0l + } + + implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { + def addInPlace(t1: Float, t2: Float) = t1 + t2 + def zero(initialValue: Float) = 0f + } + // TODO: Add AccumulatorParams for other types, e.g. lists and strings implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) = diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala index 78d64a44ae..ac8ae7d308 100644 --- a/core/src/test/scala/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/spark/AccumulatorSuite.scala @@ -17,6 +17,12 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte val d = sc.parallelize(1 to 20) d.foreach{x => acc += x} acc.value should be (210) + + + val longAcc = sc.accumulator(0l) + val maxInt = Integer.MAX_VALUE.toLong + d.foreach{x => longAcc += maxInt + x} + longAcc.value should be (210l + maxInt * 20) } test ("value not assignable from tasks") { -- cgit v1.2.3 From 1f9b486a8be49ef547ac1532cafd63c4c9d4ddda Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 28 Jan 2013 20:24:54 -0800 Subject: Some DEBUG-level log cleanup. A few changes to make the DEBUG-level logs less noisy and more readable. - Moved a few very frequent messages to Trace - Changed some BlockManger log messages to make them more understandable SPARK-666 #resolve --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 8 ++++---- core/src/main/scala/spark/storage/BlockManager.scala | 14 +++++++------- .../main/scala/spark/storage/BlockManagerMasterActor.scala | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index bd541d4207..f10d7cc84e 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -308,10 +308,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } else { // TODO: We might want to run this less often, when we are sure that something has become // runnable that wasn't before. - logDebug("Checking for newly runnable parent stages") - logDebug("running: " + running) - logDebug("waiting: " + waiting) - logDebug("failed: " + failed) + logTrace("Checking for newly runnable parent stages") + logTrace("running: " + running) + logTrace("waiting: " + waiting) + logTrace("failed: " + failed) val waiting2 = waiting.toArray waiting.clear() for (stage <- waiting2.sortBy(_.priority)) { diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 1215d5f5c8..c61fd75c2b 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -243,7 +243,7 @@ class BlockManager( val startTimeMs = System.currentTimeMillis var managers = master.getLocations(blockId) val locations = managers.map(_.ip) - logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs)) + logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -253,7 +253,7 @@ class BlockManager( def getLocations(blockIds: Array[String]): Array[Seq[String]] = { val startTimeMs = System.currentTimeMillis val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray - logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) + logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -645,7 +645,7 @@ class BlockManager( var size = 0L myInfo.synchronized { - logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") if (level.useMemory) { @@ -677,8 +677,10 @@ class BlockManager( } logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) + // Replicate block if required if (level.replication > 1) { + val remoteStartTime = System.currentTimeMillis // Serialize the block if not already done if (bytesAfterPut == null) { if (valuesAfterPut == null) { @@ -688,12 +690,10 @@ class BlockManager( bytesAfterPut = dataSerialize(blockId, valuesAfterPut) } replicate(blockId, bytesAfterPut, level) + logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime)) } - BlockManager.dispose(bytesAfterPut) - logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)) - return size } @@ -978,7 +978,7 @@ object BlockManager extends Logging { */ def dispose(buffer: ByteBuffer) { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - logDebug("Unmapping " + buffer) + logTrace("Unmapping " + buffer) if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { buffer.asInstanceOf[DirectBuffer].cleaner().clean() } diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index f88517f1a3..2830bc6297 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -115,7 +115,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } def expireDeadHosts() { - logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.") + logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.") val now = System.currentTimeMillis() val minSeenTime = now - slaveTimeout val toRemove = new HashSet[BlockManagerId] -- cgit v1.2.3 From 7ee824e42ebaa1fc0b0248e0a35021108625ed14 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 28 Jan 2013 21:48:32 -0800 Subject: Units from ms -> s --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 7ba1f3430a..b8336d9d06 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -413,10 +413,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def markStageAsFinished(stage: Stage) = { val serviceTime = stage.submissionTime match { - case Some(t) => (System.currentTimeMillis() - t).toString + case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0) case _ => "Unkown" } - logInfo("%s (%s) finished in %s ms".format(stage, stage.origin, serviceTime)) + logInfo("%s (%s) finished in %s s".format(stage, stage.origin, serviceTime)) running -= stage } event.reason match { -- cgit v1.2.3 From b45857c965219e2d26f35adb2ea3a2b831fdb77f Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 28 Jan 2013 23:56:56 -0600 Subject: Add RDD.toDebugString. Original idea by Nathan Kronenfeld. --- core/src/main/scala/spark/RDD.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 0d3857f9dd..172431c31a 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -638,4 +638,14 @@ abstract class RDD[T: ClassManifest]( protected[spark] def clearDependencies() { dependencies_ = null } + + /** A description of this RDD and its recursive dependencies for debugging. */ + def toDebugString(): String = { + def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = { + Seq(prefix + rdd) ++ rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " ")) + } + debugString(this).mkString("\n") + } + + override def toString() = "%s[%d] at %s".format(getClass.getSimpleName, id, origin) } -- cgit v1.2.3 From 951cfd9ba2a9239a777f156f10af820e9df49606 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 29 Jan 2013 00:02:17 -0600 Subject: Add JavaRDDLike.toDebugString(). --- core/src/main/scala/spark/api/java/JavaRDDLike.scala | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'core') diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 4c95c989b5..44f778e5c2 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -330,4 +330,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround case _ => Optional.absent() } } + + /** A description of this RDD and its recursive dependencies for debugging. */ + def toDebugString(): String = { + rdd.toDebugString() + } } -- cgit v1.2.3 From 3cda14af3fea97c2372c7335505e9dad7e0dd117 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 29 Jan 2013 00:12:31 -0600 Subject: Add number of splits. --- core/src/main/scala/spark/RDD.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 172431c31a..39bacd2afb 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -642,7 +642,8 @@ abstract class RDD[T: ClassManifest]( /** A description of this RDD and its recursive dependencies for debugging. */ def toDebugString(): String = { def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = { - Seq(prefix + rdd) ++ rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " ")) + Seq(prefix + rdd + " (" + rdd.splits.size + " splits)") ++ + rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " ")) } debugString(this).mkString("\n") } -- cgit v1.2.3 From cbf72bffa5874319c7ee7117a073e9d01fa51585 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 29 Jan 2013 00:20:36 -0600 Subject: Include name, if set, in RDD.toString(). --- core/src/main/scala/spark/RDD.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 39bacd2afb..a23441483e 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -648,5 +648,10 @@ abstract class RDD[T: ClassManifest]( debugString(this).mkString("\n") } - override def toString() = "%s[%d] at %s".format(getClass.getSimpleName, id, origin) + override def toString(): String = "%s%s[%d] at %s".format( + Option(name).map(_ + " ").getOrElse(""), + getClass.getSimpleName, + id, + origin) + } -- cgit v1.2.3 From 64ba6a8c2c5f46e6de6deb6a6fd576a55cb3b198 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 28 Jan 2013 22:30:12 -0800 Subject: Simplify checkpointing code and RDD class a little: - RDD's getDependencies and getSplits methods are now guaranteed to be called only once, so subclasses can safely do computation in there without worrying about caching the results. - The management of a "splits_" variable that is cleared out when we checkpoint an RDD is now done in the RDD class. - A few of the RDD subclasses are simpler. - CheckpointRDD's compute() method no longer assumes that it is given a CheckpointRDDSplit -- it can work just as well on a split from the original RDD, because it only looks at its index. This is important because things like UnionRDD and ZippedRDD remember the parent's splits as part of their own and wouldn't work on checkpointed parents. - RDD.iterator can now reuse cached data if an RDD is computed before it is checkpointed. It seems like it wouldn't do this before (it always called iterator() on the CheckpointRDD, which read from HDFS). --- core/src/main/scala/spark/CacheManager.scala | 6 +- core/src/main/scala/spark/PairRDDFunctions.scala | 4 +- core/src/main/scala/spark/RDD.scala | 130 ++++++++++++--------- core/src/main/scala/spark/RDDCheckpointData.scala | 19 +-- .../main/scala/spark/api/java/JavaRDDLike.scala | 2 +- core/src/main/scala/spark/rdd/CartesianRDD.scala | 12 +- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 61 +++++----- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 14 +-- core/src/main/scala/spark/rdd/MappedRDD.scala | 6 +- .../main/scala/spark/rdd/PartitionPruningRDD.scala | 13 +-- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 8 +- core/src/main/scala/spark/rdd/UnionRDD.scala | 14 +-- core/src/main/scala/spark/rdd/ZippedRDD.scala | 7 +- .../main/scala/spark/util/MetadataCleaner.scala | 4 +- core/src/test/scala/spark/CheckpointSuite.scala | 21 ++-- 15 files changed, 153 insertions(+), 168 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala index a0b53fd9d6..711435c333 100644 --- a/core/src/main/scala/spark/CacheManager.scala +++ b/core/src/main/scala/spark/CacheManager.scala @@ -10,9 +10,9 @@ import spark.storage.{BlockManager, StorageLevel} private[spark] class CacheManager(blockManager: BlockManager) extends Logging { private val loading = new HashSet[String] - /** Gets or computes an RDD split. Used by RDD.iterator() when a RDD is cached. */ + /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */ def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) - : Iterator[T] = { + : Iterator[T] = { val key = "rdd_%d_%d".format(rdd.id, split.index) logInfo("Cache key is " + key) blockManager.get(key) match { @@ -50,7 +50,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { // If we got here, we have to load the split val elements = new ArrayBuffer[Any] logInfo("Computing partition " + split) - elements ++= rdd.compute(split, context) + elements ++= rdd.computeOrReadCheckpoint(split, context) // Try to put this block in the blockManager blockManager.put(key, elements, storageLevel, true) return elements.iterator.asInstanceOf[Iterator[T]] diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 53b051f1c5..231e23a7de 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -649,9 +649,7 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( } private[spark] -class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) - extends RDD[(K, U)](prev) { - +class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev) { override def getSplits = firstParent[(K, V)].splits override val partitioner = firstParent[(K, V)].partitioner override def compute(split: Split, context: TaskContext) = diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 0d3857f9dd..dbad6d4c83 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -1,27 +1,17 @@ package spark -import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream} import java.net.URL import java.util.{Date, Random} import java.util.{HashMap => JHashMap} -import java.util.concurrent.atomic.AtomicLong import scala.collection.Map import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import org.apache.hadoop.fs.Path import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapred.FileOutputCommitter -import org.apache.hadoop.mapred.HadoopWriter -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.OutputCommitter -import org.apache.hadoop.mapred.OutputFormat -import org.apache.hadoop.mapred.SequenceFileOutputFormat import org.apache.hadoop.mapred.TextOutputFormat import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} @@ -30,7 +20,6 @@ import spark.partial.BoundedDouble import spark.partial.CountEvaluator import spark.partial.GroupedCountEvaluator import spark.partial.PartialResult -import spark.rdd.BlockRDD import spark.rdd.CartesianRDD import spark.rdd.FilteredRDD import spark.rdd.FlatMappedRDD @@ -73,11 +62,11 @@ import SparkContext._ * on RDD internals. */ abstract class RDD[T: ClassManifest]( - @transient var sc: SparkContext, - var dependencies_ : List[Dependency[_]] + @transient private var sc: SparkContext, + @transient private var deps: Seq[Dependency[_]] ) extends Serializable with Logging { - + /** Construct an RDD with just a one-to-one dependency on one parent */ def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) @@ -85,25 +74,27 @@ abstract class RDD[T: ClassManifest]( // Methods that should be implemented by subclasses of RDD // ======================================================================= - /** Function for computing a given partition. */ + /** Implemented by subclasses to compute a given partition. */ def compute(split: Split, context: TaskContext): Iterator[T] - /** Set of partitions in this RDD. */ - protected def getSplits(): Array[Split] + /** + * Implemented by subclasses to return the set of partitions in this RDD. This method will only + * be called once, so it is safe to implement a time-consuming computation in it. + */ + protected def getSplits: Array[Split] - /** How this RDD depends on any parent RDDs. */ - protected def getDependencies(): List[Dependency[_]] = dependencies_ + /** + * Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only + * be called once, so it is safe to implement a time-consuming computation in it. + */ + protected def getDependencies: Seq[Dependency[_]] = deps - /** A friendly name for this RDD */ - var name: String = null - /** Optionally overridden by subclasses to specify placement preferences. */ protected def getPreferredLocations(split: Split): Seq[String] = Nil /** Optionally overridden by subclasses to specify how they are partitioned. */ val partitioner: Option[Partitioner] = None - // ======================================================================= // Methods and fields available on all RDDs // ======================================================================= @@ -111,13 +102,16 @@ abstract class RDD[T: ClassManifest]( /** A unique ID for this RDD (within its SparkContext). */ val id = sc.newRddId() + /** A friendly name for this RDD */ + var name: String = null + /** Assign a name to this RDD */ def setName(_name: String) = { name = _name this } - /** + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. */ @@ -142,15 +136,24 @@ abstract class RDD[T: ClassManifest]( /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel + // Our dependencies and splits will be gotten by calling subclass's methods below, and will + // be overwritten when we're checkpointed + private var dependencies_ : Seq[Dependency[_]] = null + @transient private var splits_ : Array[Split] = null + + /** An Option holding our checkpoint RDD, if we are checkpointed */ + private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD) + /** - * Get the preferred location of a split, taking into account whether the + * Get the list of dependencies of this RDD, taking into account whether the * RDD is checkpointed or not. */ - final def preferredLocations(split: Split): Seq[String] = { - if (isCheckpointed) { - checkpointData.get.getPreferredLocations(split) - } else { - getPreferredLocations(split) + final def dependencies: Seq[Dependency[_]] = { + checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse { + if (dependencies_ == null) { + dependencies_ = getDependencies + } + dependencies_ } } @@ -159,22 +162,21 @@ abstract class RDD[T: ClassManifest]( * RDD is checkpointed or not. */ final def splits: Array[Split] = { - if (isCheckpointed) { - checkpointData.get.getSplits - } else { - getSplits + checkpointRDD.map(_.splits).getOrElse { + if (splits_ == null) { + splits_ = getSplits + } + splits_ } } /** - * Get the list of dependencies of this RDD, taking into account whether the + * Get the preferred location of a split, taking into account whether the * RDD is checkpointed or not. */ - final def dependencies: List[Dependency[_]] = { - if (isCheckpointed) { - dependencies_ - } else { - getDependencies + final def preferredLocations(split: Split): Seq[String] = { + checkpointRDD.map(_.getPreferredLocations(split)).getOrElse { + getPreferredLocations(split) } } @@ -184,10 +186,19 @@ abstract class RDD[T: ClassManifest]( * subclasses of RDD. */ final def iterator(split: Split, context: TaskContext): Iterator[T] = { - if (isCheckpointed) { - checkpointData.get.iterator(split, context) - } else if (storageLevel != StorageLevel.NONE) { + if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) + } else { + computeOrReadCheckpoint(split, context) + } + } + + /** + * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing. + */ + private[spark] def computeOrReadCheckpoint(split: Split, context: TaskContext): Iterator[T] = { + if (isCheckpointed) { + firstParent[T].iterator(split, context) } else { compute(split, context) } @@ -578,15 +589,15 @@ abstract class RDD[T: ClassManifest]( /** * Return whether this RDD has been checkpointed or not */ - def isCheckpointed(): Boolean = { - if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false + def isCheckpointed: Boolean = { + checkpointData.map(_.isCheckpointed).getOrElse(false) } /** * Gets the name of the file to which this RDD was checkpointed */ - def getCheckpointFile(): Option[String] = { - if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None + def getCheckpointFile: Option[String] = { + checkpointData.flatMap(_.getCheckpointFile) } // ======================================================================= @@ -611,31 +622,36 @@ abstract class RDD[T: ClassManifest]( def context = sc /** - * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler + * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler * after a job using this RDD has completed (therefore the RDD has been materialized and * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. */ - protected[spark] def doCheckpoint() { - if (checkpointData.isDefined) checkpointData.get.doCheckpoint() - dependencies.foreach(_.rdd.doCheckpoint()) + private[spark] def doCheckpoint() { + if (checkpointData.isDefined) { + checkpointData.get.doCheckpoint() + } else { + dependencies.foreach(_.rdd.doCheckpoint()) + } } /** - * Changes the dependencies of this RDD from its original parents to the new RDD - * (`newRDD`) created from the checkpoint file. + * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`) + * created from the checkpoint file, and forget its old dependencies and splits. */ - protected[spark] def changeDependencies(newRDD: RDD[_]) { + private[spark] def markCheckpointed(checkpointRDD: RDD[_]) { clearDependencies() - dependencies_ = List(new OneToOneDependency(newRDD)) + dependencies_ = null + splits_ = null + deps = null // Forget the constructor argument for dependencies too } /** * Clears the dependencies of this RDD. This method must ensure that all references * to the original parent RDDs is removed to enable the parent RDDs to be garbage * collected. Subclasses of RDD may override this method for implementing their own cleaning - * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + * logic. See [[spark.rdd.UnionRDD]] for an example. */ - protected[spark] def clearDependencies() { + protected def clearDependencies() { dependencies_ = null } } diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index 18df530b7d..a4a4ebaf53 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -20,7 +20,7 @@ private[spark] object CheckpointState extends Enumeration { * of the checkpointed RDD. */ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) -extends Logging with Serializable { + extends Logging with Serializable { import CheckpointState._ @@ -31,7 +31,7 @@ extends Logging with Serializable { @transient var cpFile: Option[String] = None // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. - @transient var cpRDD: Option[RDD[T]] = None + var cpRDD: Option[RDD[T]] = None // Mark the RDD for checkpointing def markForCheckpoint() { @@ -41,12 +41,12 @@ extends Logging with Serializable { } // Is the RDD already checkpointed - def isCheckpointed(): Boolean = { + def isCheckpointed: Boolean = { RDDCheckpointData.synchronized { cpState == Checkpointed } } // Get the file to which this RDD was checkpointed to as an Option - def getCheckpointFile(): Option[String] = { + def getCheckpointFile: Option[String] = { RDDCheckpointData.synchronized { cpFile } } @@ -71,7 +71,7 @@ extends Logging with Serializable { RDDCheckpointData.synchronized { cpFile = Some(path) cpRDD = Some(newRDD) - rdd.changeDependencies(newRDD) + rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and splits cpState = Checkpointed RDDCheckpointData.clearTaskCaches() logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) @@ -79,7 +79,7 @@ extends Logging with Serializable { } // Get preferred location of a split after checkpointing - def getPreferredLocations(split: Split) = { + def getPreferredLocations(split: Split): Seq[String] = { RDDCheckpointData.synchronized { cpRDD.get.preferredLocations(split) } @@ -91,9 +91,10 @@ extends Logging with Serializable { } } - // Get iterator. This is called at the worker nodes. - def iterator(split: Split, context: TaskContext): Iterator[T] = { - rdd.firstParent[T].iterator(split, context) + def checkpointRDD: Option[RDD[T]] = { + RDDCheckpointData.synchronized { + cpRDD + } } } diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 4c95c989b5..46fd8fe85e 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -319,7 +319,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround /** * Return whether this RDD has been checkpointed or not */ - def isCheckpointed(): Boolean = rdd.isCheckpointed() + def isCheckpointed: Boolean = rdd.isCheckpointed /** * Gets the name of the file to which this RDD was checkpointed diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 453d410ad4..0f9ca06531 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,7 +1,7 @@ package spark.rdd import java.io.{ObjectOutputStream, IOException} -import spark.{OneToOneDependency, NarrowDependency, RDD, SparkContext, Split, TaskContext} +import spark._ private[spark] @@ -35,7 +35,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( val numSplitsInRdd2 = rdd2.splits.size - @transient var splits_ = { + override def getSplits: Array[Split] = { // create the cross product split val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) for (s1 <- rdd1.splits; s2 <- rdd2.splits) { @@ -45,8 +45,6 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( array } - override def getSplits = splits_ - override def getPreferredLocations(split: Split) = { val currSplit = split.asInstanceOf[CartesianSplit] rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) @@ -58,7 +56,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) } - var deps_ = List( + override def getDependencies: Seq[Dependency[_]] = List( new NarrowDependency(rdd1) { def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2) }, @@ -67,11 +65,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( } ) - override def getDependencies = deps_ - override def clearDependencies() { - deps_ = Nil - splits_ = null rdd1 = null rdd2 = null } diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 6f00f6ac73..96b593ba7c 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -9,23 +9,26 @@ import org.apache.hadoop.fs.Path import java.io.{File, IOException, EOFException} import java.text.NumberFormat -private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split { - override val index: Int = idx -} +private[spark] class CheckpointRDDSplit(val index: Int) extends Split {} /** * This RDD represents a RDD checkpoint file (similar to HadoopRDD). */ private[spark] -class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) +class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String) extends RDD[T](sc, Nil) { - @transient val path = new Path(checkpointPath) - @transient val fs = path.getFileSystem(new Configuration()) + @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) @transient val splits_ : Array[Split] = { - val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted - splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray + val dirContents = fs.listStatus(new Path(checkpointPath)) + val splitFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted + val numSplits = splitFiles.size + if (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || + !splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1))) { + throw new SparkException("Invalid checkpoint directory: " + checkpointPath) + } + Array.tabulate(numSplits)(i => new CheckpointRDDSplit(i)) } checkpointData = Some(new RDDCheckpointData[T](this)) @@ -34,36 +37,34 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) override def getSplits = splits_ override def getPreferredLocations(split: Split): Seq[String] = { - val status = fs.getFileStatus(path) + val status = fs.getFileStatus(new Path(checkpointPath)) val locations = fs.getFileBlockLocations(status, 0, status.getLen) - locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost") + locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost") } override def compute(split: Split, context: TaskContext): Iterator[T] = { - CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile, context) + val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)) + CheckpointRDD.readFromFile(file, context) } override def checkpoint() { - // Do nothing. Hadoop RDD should not be checkpointed. + // Do nothing. CheckpointRDD should not be checkpointed. } } private[spark] object CheckpointRDD extends Logging { - def splitIdToFileName(splitId: Int): String = { - val numfmt = NumberFormat.getInstance() - numfmt.setMinimumIntegerDigits(5) - numfmt.setGroupingUsed(false) - "part-" + numfmt.format(splitId) + def splitIdToFile(splitId: Int): String = { + "part-%05d".format(splitId) } - def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) { + def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { val outputDir = new Path(path) val fs = outputDir.getFileSystem(new Configuration()) - val finalOutputName = splitIdToFileName(context.splitId) + val finalOutputName = splitIdToFile(ctx.splitId) val finalOutputPath = new Path(outputDir, finalOutputName) - val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId) + val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId) if (fs.exists(tempOutputPath)) { throw new IOException("Checkpoint failed: temporary path " + @@ -83,22 +84,22 @@ private[spark] object CheckpointRDD extends Logging { serializeStream.close() if (!fs.rename(tempOutputPath, finalOutputPath)) { - if (!fs.delete(finalOutputPath, true)) { - throw new IOException("Checkpoint failed: failed to delete earlier output of task " - + context.attemptId) - } - if (!fs.rename(tempOutputPath, finalOutputPath)) { + if (!fs.exists(finalOutputPath)) { + fs.delete(tempOutputPath, false) throw new IOException("Checkpoint failed: failed to save output of task: " - + context.attemptId) + + ctx.attemptId + " and final output path does not exist") + } else { + // Some other copy of this task must've finished before us and renamed it + logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it") + fs.delete(tempOutputPath, false) } } } - def readFromFile[T](path: String, context: TaskContext): Iterator[T] = { - val inputPath = new Path(path) - val fs = inputPath.getFileSystem(new Configuration()) + def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = { + val fs = path.getFileSystem(new Configuration()) val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - val fileInputStream = fs.open(inputPath, bufferSize) + val fileInputStream = fs.open(path, bufferSize) val serializer = SparkEnv.get.serializer.newInstance() val deserializeStream = serializer.deserializeStream(fileInputStream) diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 167755bbba..4c57434b65 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -27,11 +27,11 @@ private[spark] case class CoalescedRDDSplit( * or to avoid having a large number of small tasks when processing a directory with many files. */ class CoalescedRDD[T: ClassManifest]( - var prev: RDD[T], + @transient var prev: RDD[T], maxPartitions: Int) - extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs + extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies - @transient var splits_ : Array[Split] = { + override def getSplits: Array[Split] = { val prevSplits = prev.splits if (prevSplits.length < maxPartitions) { prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) } @@ -44,26 +44,20 @@ class CoalescedRDD[T: ClassManifest]( } } - override def getSplits = splits_ - override def compute(split: Split, context: TaskContext): Iterator[T] = { split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit => firstParent[T].iterator(parentSplit, context) } } - var deps_ : List[Dependency[_]] = List( + override def getDependencies: Seq[Dependency[_]] = List( new NarrowDependency(prev) { def getParents(id: Int): Seq[Int] = splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices } ) - override def getDependencies() = deps_ - override def clearDependencies() { - deps_ = Nil - splits_ = null prev = null } } diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index c6ceb272cd..5466c9c657 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -3,13 +3,11 @@ package spark.rdd import spark.{RDD, Split, TaskContext} private[spark] -class MappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: T => U) +class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U) extends RDD[U](prev) { override def getSplits = firstParent[T].splits override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context).map(f) -} \ No newline at end of file +} diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 97dd37950e..b8482338c6 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -7,23 +7,18 @@ import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} * all partitions. An example use case: If we know the RDD is partitioned by range, * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. + * + * TODO: This currently doesn't give partition IDs properly! */ class PartitionPruningRDD[T: ClassManifest]( @transient prev: RDD[T], @transient partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { - @transient - var partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].partitions - override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) - override protected def getSplits = partitions_ + override protected def getSplits = + getDependencies.head.asInstanceOf[PruneDependency[T]].partitions override val partitioner = firstParent[T].partitioner - - override def clearDependencies() { - super.clearDependencies() - partitions_ = null - } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 28ff19876d..d396478673 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -22,16 +22,10 @@ class ShuffledRDD[K, V]( override val partitioner = Some(part) - @transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) - - override def getSplits = splits_ + override def getSplits = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) } - - override def clearDependencies() { - splits_ = null - } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 82f0a44ecd..26a2d511f2 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -26,9 +26,9 @@ private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIn class UnionRDD[T: ClassManifest]( sc: SparkContext, @transient var rdds: Seq[RDD[T]]) - extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs + extends RDD[T](sc, Nil) { // Nil since we implement getDependencies - @transient var splits_ : Array[Split] = { + override def getSplits: Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) var pos = 0 for (rdd <- rdds; split <- rdd.splits) { @@ -38,20 +38,16 @@ class UnionRDD[T: ClassManifest]( array } - override def getSplits = splits_ - - @transient var deps_ = { + override def getDependencies: Seq[Dependency[_]] = { val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) pos += rdd.splits.size } - deps.toList + deps } - override def getDependencies = deps_ - override def compute(s: Split, context: TaskContext): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator(context) @@ -59,8 +55,6 @@ class UnionRDD[T: ClassManifest]( s.asInstanceOf[UnionSplit[T]].preferredLocations() override def clearDependencies() { - deps_ = null - splits_ = null rdds = null } } diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala index d950b06c85..e5df6d8c72 100644 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -32,9 +32,7 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) with Serializable { - // TODO: FIX THIS. - - @transient var splits_ : Array[Split] = { + override def getSplits: Array[Split] = { if (rdd1.splits.size != rdd2.splits.size) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } @@ -45,8 +43,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( array } - override def getSplits = splits_ - override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = { val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context)) @@ -58,7 +54,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( } override def clearDependencies() { - splits_ = null rdd1 = null rdd2 = null } diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index 6cf93a9b17..eaff7ae581 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -26,8 +26,8 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging if (delaySeconds > 0) { logDebug( - "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and " - + "period of " + periodSeconds + " secs") + "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " + + "and period of " + periodSeconds + " secs") timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 33c317720c..0b74607fb8 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits. // Note that this test is very specific to the current implementation of CartesianRDD. val ones = sc.makeRDD(1 to 100, 10).map(x => x) - ones.checkpoint // checkpoint that MappedRDD + ones.checkpoint() // checkpoint that MappedRDD val cartesian = new CartesianRDD(sc, ones, ones) val splitBeforeCheckpoint = serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit]) @@ -125,7 +125,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits. // Note that this test is very specific to the current implementation of CoalescedRDDSplits val ones = sc.makeRDD(1 to 100, 10).map(x => x) - ones.checkpoint // checkpoint that MappedRDD + ones.checkpoint() // checkpoint that MappedRDD val coalesced = new CoalescedRDD(ones, 2) val splitBeforeCheckpoint = serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit]) @@ -160,7 +160,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { // so only the RDD will reduce in serialized size, not the splits. testParentCheckpointing( rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) - } /** @@ -176,7 +175,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testRDDSplitSize: Boolean = false ) { // Generate the final RDD using given RDD operation - val baseRDD = generateLongLineageRDD + val baseRDD = generateLongLineageRDD() val operatedRDD = op(baseRDD) val parentRDD = operatedRDD.dependencies.headOption.orNull val rddType = operatedRDD.getClass.getSimpleName @@ -245,12 +244,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testRDDSplitSize: Boolean ) { // Generate the final RDD using given RDD operation - val baseRDD = generateLongLineageRDD + val baseRDD = generateLongLineageRDD() val operatedRDD = op(baseRDD) val parentRDD = operatedRDD.dependencies.head.rdd val rddType = operatedRDD.getClass.getSimpleName val parentRDDType = parentRDD.getClass.getSimpleName + // Get the splits and dependencies of the parent in case they're lazily computed + parentRDD.dependencies + parentRDD.splits + // Find serialized sizes before and after the checkpoint val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one @@ -267,7 +270,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { if (testRDDSize) { assert( rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, - "Size of " + rddType + " did not reduce after parent checkpointing parent " + parentRDDType + + "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" ) } @@ -318,10 +321,12 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } /** - * Get serialized sizes of the RDD and its splits + * Get serialized sizes of the RDD and its splits, in order to test whether the size shrinks + * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. */ def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { - (Utils.serialize(rdd).size, Utils.serialize(rdd.splits).size) + (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length, + Utils.serialize(rdd.splits).length) } /** -- cgit v1.2.3 From a34096a76de9d07518ce33111ad43b88049c1ac2 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 28 Jan 2013 22:40:16 -0800 Subject: Add easymock to POMs --- core/pom.xml | 5 +++++ pom.xml | 6 ++++++ 2 files changed, 11 insertions(+) (limited to 'core') diff --git a/core/pom.xml b/core/pom.xml index 862d3ec37a..a2b9b726a6 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -98,6 +98,11 @@ scalacheck_${scala.version} test + + org.easymock + easymock + test + com.novocode junit-interface diff --git a/pom.xml b/pom.xml index 3ea989a082..4a4ff560e7 100644 --- a/pom.xml +++ b/pom.xml @@ -273,6 +273,12 @@ 1.8 test + + org.easymock + easymock + 3.1 + test + org.scalacheck scalacheck_${scala.version} -- cgit v1.2.3 From 16a0789e10d2ac714e7c623b026c4a58ca9678d6 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 29 Jan 2013 17:09:53 -0800 Subject: Remember ConnectionManagerId used to initiate SendingConnections. This prevents ConnectionManager from getting confused if a machine has multiple host names and the one getHostName() finds happens not to be the one that was passed from, e.g., the BlockManagerMaster. --- core/src/main/scala/spark/network/Connection.scala | 15 +++++++++++---- core/src/main/scala/spark/network/ConnectionManager.scala | 3 ++- 2 files changed, 13 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index c193bf7c8d..cd5b7d57f3 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -12,7 +12,14 @@ import java.net._ private[spark] -abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging { +abstract class Connection(val channel: SocketChannel, val selector: Selector, + val remoteConnectionManagerId: ConnectionManagerId) extends Logging { + def this(channel_ : SocketChannel, selector_ : Selector) = { + this(channel_, selector_, + ConnectionManagerId.fromSocketAddress( + channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + )) + } channel.configureBlocking(false) channel.socket.setTcpNoDelay(true) @@ -25,7 +32,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() - val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress) def key() = channel.keyFor(selector) @@ -103,8 +109,9 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex } -private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector) -extends Connection(SocketChannel.open, selector_) { +private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector, + remoteId_ : ConnectionManagerId) +extends Connection(SocketChannel.open, selector_, remoteId_) { class Outbox(fair: Int = 0) { val messages = new Queue[Message]() diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 2ecd14f536..c7f226044d 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -299,7 +299,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector)) + val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, + new SendingConnection(inetSocketAddress, selector, connectionManagerId)) newConnection } val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress) -- cgit v1.2.3 From a3d14c0404d6b28433784f84086a29ecc0045a12 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 28 Jan 2013 22:41:08 -0800 Subject: Refactoring to DAGScheduler to aid testing --- core/src/main/scala/spark/SparkContext.scala | 1 + .../main/scala/spark/scheduler/DAGScheduler.scala | 29 +++++++++++++--------- 2 files changed, 18 insertions(+), 12 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index dc9b8688b3..6ae04f4a44 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -187,6 +187,7 @@ class SparkContext( taskScheduler.start() private var dagScheduler = new DAGScheduler(taskScheduler) + dagScheduler.start() /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = { diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index b130be6a38..9655961162 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -23,7 +23,14 @@ import util.{MetadataCleaner, TimeStampedHashMap} * and to report fetch failures (the submitTasks method, and code to add CompletionEvents). */ private[spark] -class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging { +class DAGScheduler(taskSched: TaskScheduler, + mapOutputTracker: MapOutputTracker, + blockManagerMaster: BlockManagerMaster, + env: SparkEnv) + extends TaskSchedulerListener with Logging { + def this(taskSched: TaskScheduler) { + this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get) + } taskSched.setListener(this) // Called by TaskScheduler to report task completions or failures. @@ -66,10 +73,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with var cacheLocs = new HashMap[Int, Array[List[String]]] - val env = SparkEnv.get - val mapOutputTracker = env.mapOutputTracker - val blockManagerMaster = env.blockManager.master - // For tracking failed nodes, we use the MapOutputTracker's generation number, which is // sent with every task. When we detect a node failing, we note the current generation number // and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask @@ -90,12 +93,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) // Start a thread to run the DAGScheduler event loop - new Thread("DAGScheduler") { - setDaemon(true) - override def run() { - DAGScheduler.this.run() - } - }.start() + def start() { + new Thread("DAGScheduler") { + setDaemon(true) + override def run() { + DAGScheduler.this.run() + } + }.start() + } def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { if (!cacheLocs.contains(rdd.id)) { @@ -546,7 +551,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) { failedGeneration(execId) = currentGeneration logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration)) - env.blockManager.master.removeExecutor(execId) + blockManagerMaster.removeExecutor(execId) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnExecutor(execId) -- cgit v1.2.3 From 9eac7d01f0880d1d3d51e922ef2566c4ee92989f Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 28 Jan 2013 22:42:35 -0800 Subject: Add DAGScheduler tests. --- .../scala/spark/scheduler/DAGSchedulerSuite.scala | 540 +++++++++++++++++++++ 1 file changed, 540 insertions(+) create mode 100644 core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala (limited to 'core') diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala new file mode 100644 index 0000000000..53f5214d7a --- /dev/null +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -0,0 +1,540 @@ +package spark.scheduler + +import scala.collection.mutable.{Map, HashMap} + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.AsyncAssertions +import org.scalatest.concurrent.TimeLimitedTests +import org.scalatest.mock.EasyMockSugar +import org.scalatest.time.{Span, Seconds} + +import org.easymock.EasyMock._ +import org.easymock.EasyMock +import org.easymock.{IAnswer, IArgumentMatcher} + +import akka.actor.ActorSystem + +import spark.storage.BlockManager +import spark.storage.BlockManagerId +import spark.storage.BlockManagerMaster +import spark.{Dependency, ShuffleDependency, OneToOneDependency} +import spark.FetchFailedException +import spark.MapOutputTracker +import spark.RDD +import spark.SparkContext +import spark.SparkException +import spark.Split +import spark.TaskContext +import spark.TaskEndReason + +import spark.{FetchFailed, Success} + +class DAGSchedulerSuite extends FunSuite + with BeforeAndAfter with EasyMockSugar with TimeLimitedTests + with AsyncAssertions with spark.Logging { + + // If we crash the DAGScheduler thread, our test will probably hang. + override val timeLimit = Span(5, Seconds) + + val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite") + var scheduler: DAGScheduler = null + var w: Waiter = null + val taskScheduler = mock[TaskScheduler] + val blockManagerMaster = mock[BlockManagerMaster] + var mapOutputTracker: MapOutputTracker = null + var schedulerThread: Thread = null + var schedulerException: Throwable = null + val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher] + val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] + + implicit val mocks = MockObjects(taskScheduler, blockManagerMaster) + + def makeBlockManagerId(host: String): BlockManagerId = + BlockManagerId("exec-" + host, host, 12345) + + def resetExpecting(f: => Unit) { + reset(taskScheduler) + reset(blockManagerMaster) + expecting(f) + } + + before { + taskSetMatchers.clear() + cacheLocations.clear() + val actorSystem = ActorSystem("test") + mapOutputTracker = new MapOutputTracker(actorSystem, true) + resetExpecting { + taskScheduler.setListener(anyObject()) + } + whenExecuting { + scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) + } + w = new Waiter + schedulerException = null + schedulerThread = new Thread("DAGScheduler under test") { + override def run() { + try { + scheduler.run() + } catch { + case t: Throwable => + logError("Got exception in DAGScheduler: ", t) + schedulerException = t + } finally { + w.dismiss() + } + } + } + schedulerThread.start + logInfo("finished before") + } + + after { + logInfo("started after") + resetExpecting { + taskScheduler.stop() + } + whenExecuting { + scheduler.stop + schedulerThread.join + } + w.await() + if (schedulerException != null) { + throw new Exception("Exception caught from scheduler thread", schedulerException) + } + } + + // Type of RDD we use for testing. Note that we should never call the real RDD compute methods. + // This is a pair RDD type so it can always be used in ShuffleDependencies. + type MyRDD = RDD[(Int, Int)] + + def makeRdd( + numSplits: Int, + dependencies: List[Dependency[_]], + locations: Seq[Seq[String]] = Nil + ): MyRDD = { + val maxSplit = numSplits - 1 + return new MyRDD(sc, dependencies) { + override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] = + throw new RuntimeException("should not be reached") + override def getSplits() = (0 to maxSplit).map(i => new Split { + override def index = i + }).toArray + override def getPreferredLocations(split: Split): Seq[String] = + if (locations.isDefinedAt(split.index)) + locations(split.index) + else + Nil + override def toString: String = "DAGSchedulerSuiteRDD " + id + } + } + + def taskSetForRdd(rdd: MyRDD): TaskSet = { + val matcher = taskSetMatchers.getOrElseUpdate(rdd, + new IArgumentMatcher { + override def matches(actual: Any): Boolean = { + val taskSet = actual.asInstanceOf[TaskSet] + taskSet.tasks(0) match { + case rt: ResultTask[_, _] => rt.rdd.id == rdd.id + case smt: ShuffleMapTask => smt.rdd.id == rdd.id + case _ => false + } + } + override def appendTo(buf: StringBuffer) { + buf.append("taskSetForRdd(" + rdd + ")") + } + }) + EasyMock.reportMatcher(matcher) + return null + } + + def expectGetLocations(): Unit = { + EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])). + andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] { + override def answer(): Seq[Seq[BlockManagerId]] = { + val blocks = getCurrentArguments()(0).asInstanceOf[Array[String]] + return blocks.map { name => + val pieces = name.split("_") + if (pieces(0) == "rdd") { + val key = pieces(1).toInt -> pieces(2).toInt + if (cacheLocations.contains(key)) { + cacheLocations(key) + } else { + Seq[BlockManagerId]() + } + } else { + Seq[BlockManagerId]() + } + }.toSeq + } + }).anyTimes() + } + + def expectStageAnd(rdd: MyRDD, results: Seq[(TaskEndReason, Any)], + preferredLocations: Option[Seq[Seq[String]]] = None)(afterSubmit: TaskSet => Unit) { + // TODO: Remember which submission + EasyMock.expect(taskScheduler.submitTasks(taskSetForRdd(rdd))).andAnswer(new IAnswer[Unit] { + override def answer(): Unit = { + val taskSet = getCurrentArguments()(0).asInstanceOf[TaskSet] + for (task <- taskSet.tasks) { + task.generation = mapOutputTracker.getGeneration + } + afterSubmit(taskSet) + preferredLocations match { + case None => + for (taskLocs <- taskSet.tasks.map(_.preferredLocations)) { + w { assert(taskLocs.size === 0) } + } + case Some(locations) => + w { assert(locations.size === taskSet.tasks.size) } + for ((expectLocs, taskLocs) <- + taskSet.tasks.map(_.preferredLocations).zip(locations)) { + w { assert(expectLocs === taskLocs) } + } + } + w { assert(taskSet.tasks.size >= results.size)} + for ((result, i) <- results.zipWithIndex) { + if (i < taskSet.tasks.size) { + scheduler.taskEnded(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()) + } + } + } + }) + } + + def expectStage(rdd: MyRDD, results: Seq[(TaskEndReason, Any)], + preferredLocations: Option[Seq[Seq[String]]] = None) { + expectStageAnd(rdd, results, preferredLocations) { _ => } + } + + def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): Array[Int] = { + return scheduler.runJob[(Int, Int), Int]( + rdd, + (context: TaskContext, it: Iterator[(Int, Int)]) => it.next._1.asInstanceOf[Int], + (0 to (rdd.splits.size - 1)), + "test-site", + allowLocal + ) + } + + def makeMapStatus(host: String, reduces: Int): MapStatus = + new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2)) + + test("zero split job") { + val rdd = makeRdd(0, Nil) + resetExpecting { + expectGetLocations() + // deliberately expect no stages to be submitted + } + whenExecuting { + assert(submitRdd(rdd) === Array[Int]()) + } + } + + test("run trivial job") { + val rdd = makeRdd(1, Nil) + resetExpecting { + expectGetLocations() + expectStage(rdd, List( (Success, 42) )) + } + whenExecuting { + assert(submitRdd(rdd) === Array(42)) + } + } + + test("local job") { + val rdd = new MyRDD(sc, Nil) { + override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] = + Array(42 -> 0).iterator + override def getSplits() = Array( new Split { override def index = 0 } ) + override def getPreferredLocations(split: Split) = Nil + override def toString = "DAGSchedulerSuite Local RDD" + } + resetExpecting { + expectGetLocations() + // deliberately expect no stages to be submitted + } + whenExecuting { + assert(submitRdd(rdd, true) === Array(42)) + } + } + + test("run trivial job w/ dependency") { + val baseRdd = makeRdd(1, Nil) + val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) + resetExpecting { + expectGetLocations() + expectStage(finalRdd, List( (Success, 42) )) + } + whenExecuting { + assert(submitRdd(finalRdd) === Array(42)) + } + } + + test("location preferences w/ dependency") { + val baseRdd = makeRdd(1, Nil) + val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) + resetExpecting { + expectGetLocations() + cacheLocations(baseRdd.id -> 0) = + Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) + expectStage(finalRdd, List( (Success, 42) ), + Some(List(Seq("hostA", "hostB")))) + } + whenExecuting { + assert(submitRdd(finalRdd) === Array(42)) + } + } + + test("trivial job failure") { + val rdd = makeRdd(1, Nil) + resetExpecting { + expectGetLocations() + expectStageAnd(rdd, List()) { taskSet => scheduler.taskSetFailed(taskSet, "test failure") } + } + whenExecuting(taskScheduler, blockManagerMaster) { + intercept[SparkException] { submitRdd(rdd) } + } + } + + test("run trivial shuffle") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(1, List(shuffleDep)) + + resetExpecting { + expectGetLocations() + expectStage(shuffleMapRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)) + )) + expectStageAnd(reduceRdd, List( (Success, 42) )) { _ => + w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) } + } + } + whenExecuting { + assert(submitRdd(reduceRdd) === Array(42)) + } + } + + test("run trivial shuffle with fetch failure") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(2, List(shuffleDep)) + + resetExpecting { + expectGetLocations() + expectStage(shuffleMapRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)) + )) + blockManagerMaster.removeExecutor("exec-hostA") + expectStage(reduceRdd, List( + (Success, 42), + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null) + )) + // partial recompute + expectStage(shuffleMapRdd, List( (Success, makeMapStatus("hostA", 1)) )) + expectStageAnd(reduceRdd, List( (Success, 43) )) { _ => + w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostA"), + makeBlockManagerId("hostB"))) } + } + } + whenExecuting { + assert(submitRdd(reduceRdd) === Array(42, 43)) + } + } + + test("ignore late map task completions") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(2, List(shuffleDep)) + + resetExpecting { + expectGetLocations() + expectStageAnd(shuffleMapRdd, List( + (Success, makeMapStatus("hostA", 1)) + )) { taskSet => + val newGeneration = mapOutputTracker.getGeneration + 1 + scheduler.executorLost("exec-hostA") + val noAccum = Map[Long, Any]() + // We rely on the event queue being ordered and increasing the generation number by 1 + // should be ignored for being too old + scheduler.taskEnded(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum) + // should work because it's a non-failed host + scheduler.taskEnded(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum) + // should be ignored for being too old + scheduler.taskEnded(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum) + // should be ignored (not end the stage) because it's too old + scheduler.taskEnded(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum) + taskSet.tasks(1).generation = newGeneration + scheduler.taskEnded(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum) + } + blockManagerMaster.removeExecutor("exec-hostA") + expectStageAnd(reduceRdd, List( + (Success, 42), (Success, 43) + )) { _ => + w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) } + } + } + whenExecuting { + assert(submitRdd(reduceRdd) === Array(42, 43)) + } + } + + test("run trivial shuffle with out-of-band failure") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(1, List(shuffleDep)) + resetExpecting { + expectGetLocations() + blockManagerMaster.removeExecutor("exec-hostA") + expectStageAnd(shuffleMapRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)) + )) { _ => scheduler.executorLost("exec-hostA") } + expectStage(shuffleMapRdd, List( + (Success, makeMapStatus("hostC", 1)) + )) + expectStageAnd(reduceRdd, List( (Success, 42) )) { _ => + w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostC"), + makeBlockManagerId("hostB"))) } + } + } + whenExecuting { + assert(submitRdd(reduceRdd) === Array(42)) + } + } + + test("recursive shuffle failures") { + val shuffleOneRdd = makeRdd(2, Nil) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val finalRdd = makeRdd(1, List(shuffleDepTwo)) + + resetExpecting { + expectGetLocations() + expectStage(shuffleOneRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)) + )) + expectStage(shuffleTwoRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostC", 1)) + )) + blockManagerMaster.removeExecutor("exec-hostA") + expectStage(finalRdd, List( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null) + )) + // triggers a partial recompute of the first stage, then the second + expectStage(shuffleOneRdd, List( + (Success, makeMapStatus("hostA", 1)) + )) + expectStage(shuffleTwoRdd, List( + (Success, makeMapStatus("hostA", 1)) + )) + expectStage(finalRdd, List( + (Success, 42) + )) + } + whenExecuting { + assert(submitRdd(finalRdd) === Array(42)) + } + } + + test("cached post-shuffle") { + val shuffleOneRdd = makeRdd(2, Nil) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val finalRdd = makeRdd(1, List(shuffleDepTwo)) + + resetExpecting { + expectGetLocations() + expectStage(shuffleOneRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)) + )) + expectStageAnd(shuffleTwoRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostC", 1)) + )){ _ => + cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) + cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) + } + blockManagerMaster.removeExecutor("exec-hostA") + expectStage(finalRdd, List( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null) + )) + // since we have a cached copy of the missing split of shuffleTwoRdd, we shouldn't + // immediately try to rerun shuffleOneRdd: + expectStage(shuffleTwoRdd, List( + (Success, makeMapStatus("hostD", 1)) + ), Some(Seq(List("hostD")))) + expectStage(finalRdd, List( + (Success, 42) + )) + } + whenExecuting { + assert(submitRdd(finalRdd) === Array(42)) + } + } + + test("cached post-shuffle but fails") { + val shuffleOneRdd = makeRdd(2, Nil) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val finalRdd = makeRdd(1, List(shuffleDepTwo)) + + resetExpecting { + expectGetLocations() + expectStage(shuffleOneRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)) + )) + expectStageAnd(shuffleTwoRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostC", 1)) + )){ _ => + cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) + cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) + } + blockManagerMaster.removeExecutor("exec-hostA") + expectStage(finalRdd, List( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null) + )) + // since we have a cached copy of the missing split of shuffleTwoRdd, we shouldn't + // immediately try to rerun shuffleOneRdd: + expectStageAnd(shuffleTwoRdd, List( + (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null) + ), Some(Seq(List("hostD")))) { _ => + w { + intercept[FetchFailedException]{ + mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0) + } + } + cacheLocations.remove(shuffleTwoRdd.id -> 0) + } + // after that fetch failure, we should refetch the cache locations and try to recompute + // the whole chain. Note that we will ignore that a fetch failure previously occured on + // this host. + expectStage(shuffleOneRdd, List( (Success, makeMapStatus("hostA", 1)) )) + expectStage(shuffleTwoRdd, List( (Success, makeMapStatus("hostA", 1)) )) + expectStage(finalRdd, List( (Success, 42) )) + } + whenExecuting { + assert(submitRdd(finalRdd) === Array(42)) + } + } +} + -- cgit v1.2.3 From 4bf3d7ea1252454ca584a3dabf26bdeab4069409 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 29 Jan 2013 19:05:45 -0800 Subject: Clear spark.master.port to cleanup for other tests --- core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index 53f5214d7a..6c577c2685 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -102,6 +102,7 @@ class DAGSchedulerSuite extends FunSuite if (schedulerException != null) { throw new Exception("Exception caught from scheduler thread", schedulerException) } + System.clearProperty("spark.master.port") } // Type of RDD we use for testing. Note that we should never call the real RDD compute methods. -- cgit v1.2.3 From 178b89204c9dbee36886e757ddaafbd079672f4a Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 30 Jan 2013 09:19:55 -0800 Subject: Refactor DAGScheduler more to allow testing without a separate thread. --- .../main/scala/spark/scheduler/DAGScheduler.scala | 176 +++++++++++++-------- 1 file changed, 111 insertions(+), 65 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 9655961162..6892509ed1 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -23,11 +23,13 @@ import util.{MetadataCleaner, TimeStampedHashMap} * and to report fetch failures (the submitTasks method, and code to add CompletionEvents). */ private[spark] -class DAGScheduler(taskSched: TaskScheduler, - mapOutputTracker: MapOutputTracker, - blockManagerMaster: BlockManagerMaster, - env: SparkEnv) - extends TaskSchedulerListener with Logging { +class DAGScheduler( + taskSched: TaskScheduler, + mapOutputTracker: MapOutputTracker, + blockManagerMaster: BlockManagerMaster, + env: SparkEnv) + extends TaskSchedulerListener with Logging { + def this(taskSched: TaskScheduler) { this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get) } @@ -203,6 +205,27 @@ class DAGScheduler(taskSched: TaskScheduler, missing.toList } + /** Returns (and does not) submit a JobSubmitted event suitable to run a given job, and + * a JobWaiter whose getResult() method will return the result of the job when it is complete. + * + * The job is assumed to have at least one partition; zero partition jobs should be handled + * without a JobSubmitted event. + */ + private[scheduler] def prepareJob[T, U: ClassManifest]( + finalRdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + callSite: String, + allowLocal: Boolean) + : (JobSubmitted, JobWaiter) = + { + assert(partitions.size > 0) + val waiter = new JobWaiter(partitions.size) + val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] + val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter) + return (toSubmit, waiter) + } + def runJob[T, U: ClassManifest]( finalRdd: RDD[T], func: (TaskContext, Iterator[T]) => U, @@ -214,9 +237,8 @@ class DAGScheduler(taskSched: TaskScheduler, if (partitions.size == 0) { return new Array[U](0) } - val waiter = new JobWaiter(partitions.size) - val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)) + val (toSubmit, waiter) = prepareJob(finalRdd, func, partitions, callSite, allowLocal) + eventQueue.put(toSubmit) waiter.getResult() match { case JobSucceeded(results: Seq[_]) => return results.asInstanceOf[Seq[U]].toArray @@ -241,6 +263,81 @@ class DAGScheduler(taskSched: TaskScheduler, return listener.getResult() // Will throw an exception if the job fails } + /** Process one event retrieved from the event queue. + * Returns true if we should stop the event loop. + */ + private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { + event match { + case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) => + val runId = nextRunId.getAndIncrement() + val finalStage = newStage(finalRDD, None, runId) + val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener) + clearCacheLocs() + logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + + " output partitions (allowLocal=" + allowLocal + ")") + logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { + // Compute very short actions like first() or take() with no parent stages locally. + runLocally(job) + } else { + activeJobs += job + resultStageToJob(finalStage) = job + submitStage(finalStage) + } + + case ExecutorLost(execId) => + handleExecutorLost(execId) + + case completion: CompletionEvent => + handleTaskCompletion(completion) + + case TaskSetFailed(taskSet, reason) => + abortStage(idToStage(taskSet.stageId), reason) + + case StopDAGScheduler => + // Cancel any active jobs + for (job <- activeJobs) { + val error = new SparkException("Job cancelled because SparkContext was shut down") + job.listener.jobFailed(error) + } + return true + } + return false + } + + /** Resubmit any failed stages. Ordinarily called after a small amount of time has passed since + * the last fetch failure. + */ + private[scheduler] def resubmitFailedStages() { + logInfo("Resubmitting failed stages") + clearCacheLocs() + val failed2 = failed.toArray + failed.clear() + for (stage <- failed2.sortBy(_.priority)) { + submitStage(stage) + } + } + + /** Check for waiting or failed stages which are now eligible for resubmission. + * Ordinarily run on every iteration of the event loop. + */ + private[scheduler] def submitWaitingStages() { + // TODO: We might want to run this less often, when we are sure that something has become + // runnable that wasn't before. + logTrace("Checking for newly runnable parent stages") + logTrace("running: " + running) + logTrace("waiting: " + waiting) + logTrace("failed: " + failed) + val waiting2 = waiting.toArray + waiting.clear() + for (stage <- waiting2.sortBy(_.priority)) { + submitStage(stage) + } + } + + /** * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure * events and responds by launching tasks. This runs in a dedicated thread and receives events @@ -251,77 +348,26 @@ class DAGScheduler(taskSched: TaskScheduler, while (true) { val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS) - val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability if (event != null) { logDebug("Got event of type " + event.getClass.getName) } - event match { - case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) => - val runId = nextRunId.getAndIncrement() - val finalStage = newStage(finalRDD, None, runId) - val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener) - clearCacheLocs() - logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + - " output partitions") - logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") - logInfo("Parents of final stage: " + finalStage.parents) - logInfo("Missing parents: " + getMissingParentStages(finalStage)) - if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { - // Compute very short actions like first() or take() with no parent stages locally. - runLocally(job) - } else { - activeJobs += job - resultStageToJob(finalStage) = job - submitStage(finalStage) - } - - case ExecutorLost(execId) => - handleExecutorLost(execId) - - case completion: CompletionEvent => - handleTaskCompletion(completion) - - case TaskSetFailed(taskSet, reason) => - abortStage(idToStage(taskSet.stageId), reason) - - case StopDAGScheduler => - // Cancel any active jobs - for (job <- activeJobs) { - val error = new SparkException("Job cancelled because SparkContext was shut down") - job.listener.jobFailed(error) - } + if (event != null) { + if (processEvent(event)) { return - - case null => - // queue.poll() timed out, ignore it + } } + val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability // Periodically resubmit failed stages if some map output fetches have failed and we have // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails, // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at // the same time, so we want to make sure we've identified all the reduce tasks that depend // on the failed node. if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { - logInfo("Resubmitting failed stages") - clearCacheLocs() - val failed2 = failed.toArray - failed.clear() - for (stage <- failed2.sortBy(_.priority)) { - submitStage(stage) - } + resubmitFailedStages } else { - // TODO: We might want to run this less often, when we are sure that something has become - // runnable that wasn't before. - logTrace("Checking for newly runnable parent stages") - logTrace("running: " + running) - logTrace("waiting: " + waiting) - logTrace("failed: " + failed) - val waiting2 = waiting.toArray - waiting.clear() - for (stage <- waiting2.sortBy(_.priority)) { - submitStage(stage) - } + submitWaitingStages } } } -- cgit v1.2.3 From 9c0bae75ade9e5b5a69077a5719adf4ee96e2c2e Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 30 Jan 2013 09:22:07 -0800 Subject: Change DAGSchedulerSuite to run DAGScheduler in the same Thread. --- .../scala/spark/scheduler/DAGSchedulerSuite.scala | 568 ++++++++++++--------- 1 file changed, 319 insertions(+), 249 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index 6c577c2685..89173540d4 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -4,12 +4,12 @@ import scala.collection.mutable.{Map, HashMap} import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.AsyncAssertions import org.scalatest.concurrent.TimeLimitedTests import org.scalatest.mock.EasyMockSugar import org.scalatest.time.{Span, Seconds} import org.easymock.EasyMock._ +import org.easymock.Capture import org.easymock.EasyMock import org.easymock.{IAnswer, IArgumentMatcher} @@ -30,33 +30,55 @@ import spark.TaskEndReason import spark.{FetchFailed, Success} -class DAGSchedulerSuite extends FunSuite - with BeforeAndAfter with EasyMockSugar with TimeLimitedTests - with AsyncAssertions with spark.Logging { +class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests { - // If we crash the DAGScheduler thread, our test will probably hang. + // impose a time limit on this test in case we don't let the job finish. override val timeLimit = Span(5, Seconds) val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite") var scheduler: DAGScheduler = null - var w: Waiter = null val taskScheduler = mock[TaskScheduler] val blockManagerMaster = mock[BlockManagerMaster] var mapOutputTracker: MapOutputTracker = null var schedulerThread: Thread = null var schedulerException: Throwable = null + + /** Set of EasyMock argument matchers that match a TaskSet for a given RDD. + * We cache these so we do not create duplicate matchers for the same RDD. + * This allows us to easily setup a sequence of expectations for task sets for + * that RDD. + */ val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher] + + /** Set of cache locations to return from our mock BlockManagerMaster. + * Keys are (rdd ID, partition ID). Anything not present will return an empty + * list of cache locations silently. + */ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] - implicit val mocks = MockObjects(taskScheduler, blockManagerMaster) + /** JobWaiter for the last JobSubmitted event we pushed. To keep tests (most of which + * will only submit one job) from needing to explicitly track it. + */ + var lastJobWaiter: JobWaiter = null - def makeBlockManagerId(host: String): BlockManagerId = - BlockManagerId("exec-" + host, host, 12345) + /** Tell EasyMockSugar what mock objects we want to be configured by expecting {...} + * and whenExecuting {...} */ + implicit val mocks = MockObjects(taskScheduler, blockManagerMaster) + /** Utility function to reset mocks and set expectations on them. EasyMock wants mock objects + * to be reset after each time their expectations are set, and we tend to check mock object + * calls over a single call to DAGScheduler. + * + * We also set a default expectation here that blockManagerMaster.getLocations can be called + * and will return values from cacheLocations. + */ def resetExpecting(f: => Unit) { reset(taskScheduler) reset(blockManagerMaster) - expecting(f) + expecting { + expectGetLocations() + f + } } before { @@ -70,45 +92,30 @@ class DAGSchedulerSuite extends FunSuite whenExecuting { scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) } - w = new Waiter - schedulerException = null - schedulerThread = new Thread("DAGScheduler under test") { - override def run() { - try { - scheduler.run() - } catch { - case t: Throwable => - logError("Got exception in DAGScheduler: ", t) - schedulerException = t - } finally { - w.dismiss() - } - } - } - schedulerThread.start - logInfo("finished before") } after { - logInfo("started after") + assert(scheduler.processEvent(StopDAGScheduler)) resetExpecting { taskScheduler.stop() } whenExecuting { - scheduler.stop - schedulerThread.join - } - w.await() - if (schedulerException != null) { - throw new Exception("Exception caught from scheduler thread", schedulerException) + scheduler.stop() } System.clearProperty("spark.master.port") } - // Type of RDD we use for testing. Note that we should never call the real RDD compute methods. - // This is a pair RDD type so it can always be used in ShuffleDependencies. + def makeBlockManagerId(host: String): BlockManagerId = + BlockManagerId("exec-" + host, host, 12345) + + /** Type of RDD we use for testing. Note that we should never call the real RDD compute methods. + * This is a pair RDD type so it can always be used in ShuffleDependencies. */ type MyRDD = RDD[(Int, Int)] + /** Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and + * preferredLocations (if any) that are passed to them. They are deliberately not executable + * so we can test that DAGScheduler does not try to execute RDDs locally. + */ def makeRdd( numSplits: Int, dependencies: List[Dependency[_]], @@ -130,6 +137,9 @@ class DAGSchedulerSuite extends FunSuite } } + /** EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task + * is from a particular RDD. + */ def taskSetForRdd(rdd: MyRDD): TaskSet = { val matcher = taskSetMatchers.getOrElseUpdate(rdd, new IArgumentMatcher { @@ -149,6 +159,9 @@ class DAGSchedulerSuite extends FunSuite return null } + /** Setup an EasyMock expectation to repsond to blockManagerMaster.getLocations() called from + * cacheLocations. + */ def expectGetLocations(): Unit = { EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])). andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] { @@ -171,51 +184,106 @@ class DAGSchedulerSuite extends FunSuite }).anyTimes() } - def expectStageAnd(rdd: MyRDD, results: Seq[(TaskEndReason, Any)], - preferredLocations: Option[Seq[Seq[String]]] = None)(afterSubmit: TaskSet => Unit) { - // TODO: Remember which submission - EasyMock.expect(taskScheduler.submitTasks(taskSetForRdd(rdd))).andAnswer(new IAnswer[Unit] { - override def answer(): Unit = { - val taskSet = getCurrentArguments()(0).asInstanceOf[TaskSet] - for (task <- taskSet.tasks) { - task.generation = mapOutputTracker.getGeneration - } - afterSubmit(taskSet) - preferredLocations match { - case None => - for (taskLocs <- taskSet.tasks.map(_.preferredLocations)) { - w { assert(taskLocs.size === 0) } - } - case Some(locations) => - w { assert(locations.size === taskSet.tasks.size) } - for ((expectLocs, taskLocs) <- - taskSet.tasks.map(_.preferredLocations).zip(locations)) { - w { assert(expectLocs === taskLocs) } - } - } - w { assert(taskSet.tasks.size >= results.size)} - for ((result, i) <- results.zipWithIndex) { - if (i < taskSet.tasks.size) { - scheduler.taskEnded(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()) - } - } + /** Process the supplied event as if it were the top of the DAGScheduler event queue, expecting + * the scheduler not to exit. + * + * After processing the event, submit waiting stages as is done on most iterations of the + * DAGScheduler event loop. + */ + def runEvent(event: DAGSchedulerEvent) { + assert(!scheduler.processEvent(event)) + scheduler.submitWaitingStages() + } + + /** Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be + * called from a resetExpecting { ... } block. + * + * Returns a easymock Capture that will contain the task set after the stage is submitted. + * Most tests should use interceptStage() instead of this directly. + */ + def expectStage(rdd: MyRDD): Capture[TaskSet] = { + val taskSetCapture = new Capture[TaskSet] + taskScheduler.submitTasks(and(capture(taskSetCapture), taskSetForRdd(rdd))) + return taskSetCapture + } + + /** Expect the supplied code snippet to submit a stage for the specified RDD. + * Return the resulting TaskSet. First marks all the tasks are belonging to the + * current MapOutputTracker generation. + */ + def interceptStage(rdd: MyRDD)(f: => Unit): TaskSet = { + var capture: Capture[TaskSet] = null + resetExpecting { + capture = expectStage(rdd) + } + whenExecuting { + f + } + val taskSet = capture.getValue + for (task <- taskSet.tasks) { + task.generation = mapOutputTracker.getGeneration + } + return taskSet + } + + /** Send the given CompletionEvent messages for the tasks in the TaskSet. */ + def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) { + assert(taskSet.tasks.size >= results.size) + for ((result, i) <- results.zipWithIndex) { + if (i < taskSet.tasks.size) { + runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]())) } - }) + } } - def expectStage(rdd: MyRDD, results: Seq[(TaskEndReason, Any)], - preferredLocations: Option[Seq[Seq[String]]] = None) { - expectStageAnd(rdd, results, preferredLocations) { _ => } + /** Assert that the supplied TaskSet has exactly the given preferredLocations. */ + def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) { + assert(locations.size === taskSet.tasks.size) + for ((expectLocs, taskLocs) <- + taskSet.tasks.map(_.preferredLocations).zip(locations)) { + assert(expectLocs === taskLocs) + } } - def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): Array[Int] = { - return scheduler.runJob[(Int, Int), Int]( + /** When we submit dummy Jobs, this is the compute function we supply. Except in a local test + * below, we do not expect this function to ever be executed; instead, we will return results + * directly through CompletionEvents. + */ + def jobComputeFunc(context: TaskContext, it: Iterator[(Int, Int)]): Int = + it.next._1.asInstanceOf[Int] + + + /** Start a job to compute the given RDD. Returns the JobWaiter that will + * collect the result of the job via callbacks from DAGScheduler. */ + def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): JobWaiter = { + val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int]( rdd, - (context: TaskContext, it: Iterator[(Int, Int)]) => it.next._1.asInstanceOf[Int], + jobComputeFunc, (0 to (rdd.splits.size - 1)), "test-site", allowLocal ) + lastJobWaiter = waiter + runEvent(toSubmit) + return waiter + } + + /** Assert that a job we started has failed. */ + def expectJobException(waiter: JobWaiter = lastJobWaiter) { + waiter.getResult match { + case JobSucceeded(_) => fail() + case JobFailed(_) => return + } + } + + /** Assert that a job we started has succeeded and has the given result. */ + def expectJobResult(expected: Array[Int], waiter: JobWaiter = lastJobWaiter) { + waiter.getResult match { + case JobSucceeded(answer) => + assert(expected === answer.asInstanceOf[Seq[Int]].toArray ) + case JobFailed(_) => + fail() + } } def makeMapStatus(host: String, reduces: Int): MapStatus = @@ -223,24 +291,14 @@ class DAGSchedulerSuite extends FunSuite test("zero split job") { val rdd = makeRdd(0, Nil) - resetExpecting { - expectGetLocations() - // deliberately expect no stages to be submitted - } - whenExecuting { - assert(submitRdd(rdd) === Array[Int]()) - } + assert(scheduler.runJob(rdd, jobComputeFunc, Seq(), "test-site", false) === Array[Int]()) } test("run trivial job") { val rdd = makeRdd(1, Nil) - resetExpecting { - expectGetLocations() - expectStage(rdd, List( (Success, 42) )) - } - whenExecuting { - assert(submitRdd(rdd) === Array(42)) - } + val taskSet = interceptStage(rdd) { submitRdd(rdd) } + respondToTaskSet(taskSet, List( (Success, 42) )) + expectJobResult(Array(42)) } test("local job") { @@ -251,51 +309,34 @@ class DAGSchedulerSuite extends FunSuite override def getPreferredLocations(split: Split) = Nil override def toString = "DAGSchedulerSuite Local RDD" } - resetExpecting { - expectGetLocations() - // deliberately expect no stages to be submitted - } - whenExecuting { - assert(submitRdd(rdd, true) === Array(42)) - } + submitRdd(rdd, true) + expectJobResult(Array(42)) } test("run trivial job w/ dependency") { val baseRdd = makeRdd(1, Nil) val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) - resetExpecting { - expectGetLocations() - expectStage(finalRdd, List( (Success, 42) )) - } - whenExecuting { - assert(submitRdd(finalRdd) === Array(42)) - } + val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) } + respondToTaskSet(taskSet, List( (Success, 42) )) + expectJobResult(Array(42)) } - test("location preferences w/ dependency") { + test("cache location preferences w/ dependency") { val baseRdd = makeRdd(1, Nil) val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) - resetExpecting { - expectGetLocations() - cacheLocations(baseRdd.id -> 0) = - Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) - expectStage(finalRdd, List( (Success, 42) ), - Some(List(Seq("hostA", "hostB")))) - } - whenExecuting { - assert(submitRdd(finalRdd) === Array(42)) - } + cacheLocations(baseRdd.id -> 0) = + Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) + val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) } + expectTaskSetLocations(taskSet, List(Seq("hostA", "hostB"))) + respondToTaskSet(taskSet, List( (Success, 42) )) + expectJobResult(Array(42)) } test("trivial job failure") { val rdd = makeRdd(1, Nil) - resetExpecting { - expectGetLocations() - expectStageAnd(rdd, List()) { taskSet => scheduler.taskSetFailed(taskSet, "test failure") } - } - whenExecuting(taskScheduler, blockManagerMaster) { - intercept[SparkException] { submitRdd(rdd) } - } + val taskSet = interceptStage(rdd) { submitRdd(rdd) } + runEvent(TaskSetFailed(taskSet, "test failure")) + expectJobException() } test("run trivial shuffle") { @@ -304,20 +345,17 @@ class DAGSchedulerSuite extends FunSuite val shuffleId = shuffleDep.shuffleId val reduceRdd = makeRdd(1, List(shuffleDep)) - resetExpecting { - expectGetLocations() - expectStage(shuffleMapRdd, List( + val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) } + val secondStage = interceptStage(reduceRdd) { + respondToTaskSet(firstStage, List( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)) )) - expectStageAnd(reduceRdd, List( (Success, 42) )) { _ => - w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) } - } - } - whenExecuting { - assert(submitRdd(reduceRdd) === Array(42)) } + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + respondToTaskSet(secondStage, List( (Success, 42) )) + expectJobResult(Array(42)) } test("run trivial shuffle with fetch failure") { @@ -326,28 +364,32 @@ class DAGSchedulerSuite extends FunSuite val shuffleId = shuffleDep.shuffleId val reduceRdd = makeRdd(2, List(shuffleDep)) - resetExpecting { - expectGetLocations() - expectStage(shuffleMapRdd, List( + val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) } + val secondStage = interceptStage(reduceRdd) { + respondToTaskSet(firstStage, List( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)) )) + } + resetExpecting { blockManagerMaster.removeExecutor("exec-hostA") - expectStage(reduceRdd, List( + } + whenExecuting { + respondToTaskSet(secondStage, List( (Success, 42), (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null) )) - // partial recompute - expectStage(shuffleMapRdd, List( (Success, makeMapStatus("hostA", 1)) )) - expectStageAnd(reduceRdd, List( (Success, 43) )) { _ => - w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostA"), - makeBlockManagerId("hostB"))) } - } } - whenExecuting { - assert(submitRdd(reduceRdd) === Array(42, 43)) + val thirdStage = interceptStage(shuffleMapRdd) { + scheduler.resubmitFailedStages() + } + val fourthStage = interceptStage(reduceRdd) { + respondToTaskSet(thirdStage, List( (Success, makeMapStatus("hostA", 1)) )) } + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + respondToTaskSet(fourthStage, List( (Success, 43) )) + expectJobResult(Array(42, 43)) } test("ignore late map task completions") { @@ -356,63 +398,64 @@ class DAGSchedulerSuite extends FunSuite val shuffleId = shuffleDep.shuffleId val reduceRdd = makeRdd(2, List(shuffleDep)) + val taskSet = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) } + val oldGeneration = mapOutputTracker.getGeneration resetExpecting { - expectGetLocations() - expectStageAnd(shuffleMapRdd, List( - (Success, makeMapStatus("hostA", 1)) - )) { taskSet => - val newGeneration = mapOutputTracker.getGeneration + 1 - scheduler.executorLost("exec-hostA") - val noAccum = Map[Long, Any]() - // We rely on the event queue being ordered and increasing the generation number by 1 - // should be ignored for being too old - scheduler.taskEnded(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum) - // should work because it's a non-failed host - scheduler.taskEnded(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum) - // should be ignored for being too old - scheduler.taskEnded(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum) - // should be ignored (not end the stage) because it's too old - scheduler.taskEnded(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum) - taskSet.tasks(1).generation = newGeneration - scheduler.taskEnded(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum) - } blockManagerMaster.removeExecutor("exec-hostA") - expectStageAnd(reduceRdd, List( - (Success, 42), (Success, 43) - )) { _ => - w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) } - } } whenExecuting { - assert(submitRdd(reduceRdd) === Array(42, 43)) - } + runEvent(ExecutorLost("exec-hostA")) + } + val newGeneration = mapOutputTracker.getGeneration + assert(newGeneration > oldGeneration) + val noAccum = Map[Long, Any]() + // We rely on the event queue being ordered and increasing the generation number by 1 + // should be ignored for being too old + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum)) + // should work because it's a non-failed host + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum)) + // should be ignored for being too old + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum)) + taskSet.tasks(1).generation = newGeneration + val secondStage = interceptStage(reduceRdd) { + runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum)) + } + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + respondToTaskSet(secondStage, List( (Success, 42), (Success, 43) )) + expectJobResult(Array(42, 43)) } - test("run trivial shuffle with out-of-band failure") { + test("run trivial shuffle with out-of-band failure and retry") { val shuffleMapRdd = makeRdd(2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) val shuffleId = shuffleDep.shuffleId val reduceRdd = makeRdd(1, List(shuffleDep)) + + val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) } resetExpecting { - expectGetLocations() blockManagerMaster.removeExecutor("exec-hostA") - expectStageAnd(shuffleMapRdd, List( + } + whenExecuting { + runEvent(ExecutorLost("exec-hostA")) + } + // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks + // rather than marking it is as failed and waiting. + val secondStage = interceptStage(shuffleMapRdd) { + respondToTaskSet(firstStage, List( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)) - )) { _ => scheduler.executorLost("exec-hostA") } - expectStage(shuffleMapRdd, List( - (Success, makeMapStatus("hostC", 1)) )) - expectStageAnd(reduceRdd, List( (Success, 42) )) { _ => - w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostC"), - makeBlockManagerId("hostB"))) } - } } - whenExecuting { - assert(submitRdd(reduceRdd) === Array(42)) + val thirdStage = interceptStage(reduceRdd) { + respondToTaskSet(secondStage, List( + (Success, makeMapStatus("hostC", 1)) + )) } + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + respondToTaskSet(thirdStage, List( (Success, 42) )) + expectJobResult(Array(42)) } test("recursive shuffle failures") { @@ -422,34 +465,42 @@ class DAGSchedulerSuite extends FunSuite val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) val finalRdd = makeRdd(1, List(shuffleDepTwo)) - resetExpecting { - expectGetLocations() - expectStage(shuffleOneRdd, List( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)) + val firstStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) } + val secondStage = interceptStage(shuffleTwoRdd) { + respondToTaskSet(firstStage, List( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)) )) - expectStage(shuffleTwoRdd, List( + } + val thirdStage = interceptStage(finalRdd) { + respondToTaskSet(secondStage, List( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostC", 1)) )) + } + resetExpecting { blockManagerMaster.removeExecutor("exec-hostA") - expectStage(finalRdd, List( + } + whenExecuting { + respondToTaskSet(thirdStage, List( (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null) )) - // triggers a partial recompute of the first stage, then the second - expectStage(shuffleOneRdd, List( - (Success, makeMapStatus("hostA", 1)) + } + val recomputeOne = interceptStage(shuffleOneRdd) { + scheduler.resubmitFailedStages + } + val recomputeTwo = interceptStage(shuffleTwoRdd) { + respondToTaskSet(recomputeOne, List( + (Success, makeMapStatus("hostA", 2)) )) - expectStage(shuffleTwoRdd, List( + } + val finalStage = interceptStage(finalRdd) { + respondToTaskSet(recomputeTwo, List( (Success, makeMapStatus("hostA", 1)) )) - expectStage(finalRdd, List( - (Success, 42) - )) - } - whenExecuting { - assert(submitRdd(finalRdd) === Array(42)) } + respondToTaskSet(finalStage, List( (Success, 42) )) + expectJobResult(Array(42)) } test("cached post-shuffle") { @@ -459,35 +510,41 @@ class DAGSchedulerSuite extends FunSuite val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) val finalRdd = makeRdd(1, List(shuffleDepTwo)) - resetExpecting { - expectGetLocations() - expectStage(shuffleOneRdd, List( + val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) } + cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) + cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) + val secondShuffleStage = interceptStage(shuffleTwoRdd) { + respondToTaskSet(firstShuffleStage, List( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)) + )) + } + val reduceStage = interceptStage(finalRdd) { + respondToTaskSet(secondShuffleStage, List( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)) )) - expectStageAnd(shuffleTwoRdd, List( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostC", 1)) - )){ _ => - cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) - cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) - } + } + resetExpecting { blockManagerMaster.removeExecutor("exec-hostA") - expectStage(finalRdd, List( + } + whenExecuting { + respondToTaskSet(reduceStage, List( (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null) )) - // since we have a cached copy of the missing split of shuffleTwoRdd, we shouldn't - // immediately try to rerun shuffleOneRdd: - expectStage(shuffleTwoRdd, List( + } + // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. + val recomputeTwo = interceptStage(shuffleTwoRdd) { + scheduler.resubmitFailedStages() + } + expectTaskSetLocations(recomputeTwo, Seq(Seq("hostD"))) + val finalRetry = interceptStage(finalRdd) { + respondToTaskSet(recomputeTwo, List( (Success, makeMapStatus("hostD", 1)) - ), Some(Seq(List("hostD")))) - expectStage(finalRdd, List( - (Success, 42) )) } - whenExecuting { - assert(submitRdd(finalRdd) === Array(42)) - } + respondToTaskSet(finalRetry, List( (Success, 42) )) + expectJobResult(Array(42)) } test("cached post-shuffle but fails") { @@ -497,45 +554,58 @@ class DAGSchedulerSuite extends FunSuite val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) val finalRdd = makeRdd(1, List(shuffleDepTwo)) - resetExpecting { - expectGetLocations() - expectStage(shuffleOneRdd, List( + val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) } + cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) + cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) + val secondShuffleStage = interceptStage(shuffleTwoRdd) { + respondToTaskSet(firstShuffleStage, List( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)) + )) + } + val reduceStage = interceptStage(finalRdd) { + respondToTaskSet(secondShuffleStage, List( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)) )) - expectStageAnd(shuffleTwoRdd, List( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostC", 1)) - )){ _ => - cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) - cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) - } + } + resetExpecting { blockManagerMaster.removeExecutor("exec-hostA") - expectStage(finalRdd, List( + } + whenExecuting { + respondToTaskSet(reduceStage, List( (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null) )) - // since we have a cached copy of the missing split of shuffleTwoRdd, we shouldn't - // immediately try to rerun shuffleOneRdd: - expectStageAnd(shuffleTwoRdd, List( - (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null) - ), Some(Seq(List("hostD")))) { _ => - w { - intercept[FetchFailedException]{ - mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0) - } - } - cacheLocations.remove(shuffleTwoRdd.id -> 0) - } - // after that fetch failure, we should refetch the cache locations and try to recompute - // the whole chain. Note that we will ignore that a fetch failure previously occured on - // this host. - expectStage(shuffleOneRdd, List( (Success, makeMapStatus("hostA", 1)) )) - expectStage(shuffleTwoRdd, List( (Success, makeMapStatus("hostA", 1)) )) - expectStage(finalRdd, List( (Success, 42) )) } - whenExecuting { - assert(submitRdd(finalRdd) === Array(42)) + val recomputeTwoCached = interceptStage(shuffleTwoRdd) { + scheduler.resubmitFailedStages() + } + expectTaskSetLocations(recomputeTwoCached, Seq(Seq("hostD"))) + intercept[FetchFailedException]{ + mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0) + } + + // Simulate the shuffle input data failing to be cached. + cacheLocations.remove(shuffleTwoRdd.id -> 0) + respondToTaskSet(recomputeTwoCached, List( + (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null) + )) + + // After the fetch failure, DAGScheduler should recheck the cache and decide to resubmit + // everything. + val recomputeOne = interceptStage(shuffleOneRdd) { + scheduler.resubmitFailedStages() } + // We use hostA here to make sure DAGScheduler doesn't think it's still dead. + val recomputeTwoUncached = interceptStage(shuffleTwoRdd) { + respondToTaskSet(recomputeOne, List( (Success, makeMapStatus("hostA", 1)) )) + } + expectTaskSetLocations(recomputeTwoUncached, Seq(Seq[String]())) + val finalRetry = interceptStage(finalRdd) { + respondToTaskSet(recomputeTwoUncached, List( (Success, makeMapStatus("hostA", 1)) )) + + } + respondToTaskSet(finalRetry, List( (Success, 42) )) + expectJobResult(Array(42)) } } - -- cgit v1.2.3 From 7f51458774ce4561f1df3ba9b68704c3f63852f3 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 30 Jan 2013 09:34:53 -0800 Subject: Comment at top of DAGSchedulerSuite --- .../test/scala/spark/scheduler/DAGSchedulerSuite.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index 89173540d4..c31e2e7064 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -30,9 +30,22 @@ import spark.TaskEndReason import spark.{FetchFailed, Success} +/** + * Tests for DAGScheduler. These tests directly call the event processing functinos in DAGScheduler + * rather than spawning an event loop thread as happens in the real code. They use EasyMock + * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are + * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead + * host notifications are sent). In addition, tests may check for side effects on a non-mocked + * MapOutputTracker instance. + * + * Tests primarily consist of running DAGScheduler#processEvent and + * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet) + * and capturing the resulting TaskSets from the mock TaskScheduler. + */ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests { - // impose a time limit on this test in case we don't let the job finish. + // impose a time limit on this test in case we don't let the job finish, in which case + // JobWaiter#getResult will hang. override val timeLimit = Span(5, Seconds) val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite") -- cgit v1.2.3 From f7de6978c14a331683e4a341fccd6e4c5e9fa523 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 29 Jan 2013 14:03:05 -0800 Subject: Use Mesos ExecutorIDs to hold SlaveIDs. Then we can safely use the Mesos ExecutorID as a Spark ExecutorID. --- .../spark/executor/MesosExecutorBackend.scala | 6 ++++- .../scheduler/mesos/MesosSchedulerBackend.scala | 30 ++++++++++++---------- 2 files changed, 21 insertions(+), 15 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala index 1ef88075ad..b981b26916 100644 --- a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala @@ -32,7 +32,11 @@ private[spark] class MesosExecutorBackend(executor: Executor) logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue) this.driver = driver val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) - executor.initialize(executorInfo.getExecutorId.getValue, slaveInfo.getHostname, properties) + executor.initialize( + slaveInfo.getId.getValue + "-" + executorInfo.getExecutorId.getValue, + slaveInfo.getHostname, + properties + ) } override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index f3467db86b..eab1c60e0b 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -51,7 +51,7 @@ private[spark] class MesosSchedulerBackend( val taskIdToSlaveId = new HashMap[Long, String] // An ExecutorInfo for our tasks - var executorInfo: ExecutorInfo = null + var execArgs: Array[Byte] = null override def start() { synchronized { @@ -70,12 +70,11 @@ private[spark] class MesosSchedulerBackend( } }.start() - executorInfo = createExecutorInfo() waitForRegister() } } - def createExecutorInfo(): ExecutorInfo = { + def createExecutorInfo(execId: String): ExecutorInfo = { val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( "Spark home is not set; set it through the spark.home system " + "property, the SPARK_HOME environment variable or the SparkContext constructor")) @@ -97,7 +96,7 @@ private[spark] class MesosSchedulerBackend( .setEnvironment(environment) .build() ExecutorInfo.newBuilder() - .setExecutorId(ExecutorID.newBuilder().setValue("default").build()) + .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) .addResources(memory) @@ -109,17 +108,20 @@ private[spark] class MesosSchedulerBackend( * containing all the spark.* system properties in the form of (String, String) pairs. */ private def createExecArg(): Array[Byte] = { - val props = new HashMap[String, String] - val iterator = System.getProperties.entrySet.iterator - while (iterator.hasNext) { - val entry = iterator.next - val (key, value) = (entry.getKey.toString, entry.getValue.toString) - if (key.startsWith("spark.")) { - props(key) = value + if (execArgs == null) { + val props = new HashMap[String, String] + val iterator = System.getProperties.entrySet.iterator + while (iterator.hasNext) { + val entry = iterator.next + val (key, value) = (entry.getKey.toString, entry.getValue.toString) + if (key.startsWith("spark.")) { + props(key) = value + } } + // Serialize the map as an array of (String, String) pairs + execArgs = Utils.serialize(props.toArray) } - // Serialize the map as an array of (String, String) pairs - return Utils.serialize(props.toArray) + return execArgs } override def offerRescinded(d: SchedulerDriver, o: OfferID) {} @@ -216,7 +218,7 @@ private[spark] class MesosSchedulerBackend( return MesosTaskInfo.newBuilder() .setTaskId(taskId) .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) - .setExecutor(executorInfo) + .setExecutor(createExecutorInfo(slaveId)) .setName(task.name) .addResources(cpuResource) .setData(ByteString.copyFrom(task.serializedTask)) -- cgit v1.2.3 From 252845d3046034d6e779bd7245d2f876debba8fd Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 30 Jan 2013 10:38:06 -0800 Subject: Remove remants of attempt to use slaveId-executorId in MesosExecutorBackend --- core/src/main/scala/spark/executor/MesosExecutorBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala index b981b26916..818d6d1dda 100644 --- a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala @@ -33,7 +33,7 @@ private[spark] class MesosExecutorBackend(executor: Executor) this.driver = driver val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) executor.initialize( - slaveInfo.getId.getValue + "-" + executorInfo.getExecutorId.getValue, + executorInfo.getExecutorId.getValue, slaveInfo.getHostname, properties ) -- cgit v1.2.3 From 871476d506a2d543482defb923a42a2a01f206ab Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Wed, 30 Jan 2013 16:56:46 -0600 Subject: Include message and exitStatus if availalbe. --- core/src/main/scala/spark/deploy/worker/Worker.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 5a83a42daf..8b41620d98 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -134,7 +134,9 @@ private[spark] class Worker( val fullId = jobId + "/" + execId if (ExecutorState.isFinished(state)) { val executor = executors(fullId) - logInfo("Executor " + fullId + " finished with state " + state) + logInfo("Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) finishedExecutors(fullId) = executor executors -= fullId coresUsed -= executor.cores -- cgit v1.2.3 From c1df24d0850b0ac89f35f1a47ce6b2fb5b95df0a Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 30 Jan 2013 18:51:14 -0800 Subject: rename Slaves --> Executor --- core/src/main/scala/spark/SparkContext.scala | 6 +++--- core/src/main/scala/spark/storage/BlockManagerUI.scala | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index a09eca1dd0..39e3555de8 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -468,7 +468,7 @@ class SparkContext( * Return a map from the slave to the max memory available for caching and the remaining * memory available for caching. */ - def getSlavesMemoryStatus: Map[String, (Long, Long)] = { + def getExecutorMemoryStatus: Map[String, (Long, Long)] = { env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => (blockManagerId.ip + ":" + blockManagerId.port, mem) } @@ -479,13 +479,13 @@ class SparkContext( * they take, etc. */ def getRDDStorageInfo : Array[RDDInfo] = { - StorageUtils.rddInfoFromStorageStatus(getSlavesStorageStatus, this) + StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this) } /** * Return information about blocks stored in all of the slaves */ - def getSlavesStorageStatus : Array[StorageStatus] = { + def getExecutorStorageStatus : Array[StorageStatus] = { env.blockManager.master.getStorageStatus } diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 52f6d1b657..9e6721ec17 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -45,7 +45,7 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, path("") { completeWith { // Request the current storage status from the Master - val storageStatusList = sc.getSlavesStorageStatus + val storageStatusList = sc.getExecutorStorageStatus // Calculate macro-level statistics val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) @@ -60,7 +60,7 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, parameter("id") { id => completeWith { val prefix = "rdd_" + id.toString - val storageStatusList = sc.getSlavesStorageStatus + val storageStatusList = sc.getExecutorStorageStatus val filteredStorageStatusList = StorageUtils. filterStorageStatusByPrefix(storageStatusList, prefix) val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head -- cgit v1.2.3 From fe3eceab5724bec0103471eb905bb9701120b04a Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Thu, 31 Jan 2013 13:30:41 -0800 Subject: Remove activation of profiles by default See the discussion at https://github.com/mesos/spark/pull/355 for why default profile activation is a problem. --- bagel/pom.xml | 11 ----------- core/pom.xml | 11 ----------- examples/pom.xml | 11 ----------- pom.xml | 11 ----------- repl-bin/pom.xml | 11 ----------- repl/pom.xml | 11 ----------- streaming/pom.xml | 11 ----------- 7 files changed, 77 deletions(-) (limited to 'core') diff --git a/bagel/pom.xml b/bagel/pom.xml index 5f58347204..a8256a6e8b 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -45,11 +45,6 @@ hadoop1 - - - !hadoopVersion - - org.spark-project @@ -77,12 +72,6 @@ hadoop2 - - - hadoopVersion - 2 - - org.spark-project diff --git a/core/pom.xml b/core/pom.xml index 862d3ec37a..873e8a1d0f 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -163,11 +163,6 @@ hadoop1 - - - !hadoopVersion - - org.apache.hadoop @@ -220,12 +215,6 @@ hadoop2 - - - hadoopVersion - 2 - - org.apache.hadoop diff --git a/examples/pom.xml b/examples/pom.xml index 4d43103475..f43af670c6 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -50,11 +50,6 @@ hadoop1 - - - !hadoopVersion - - org.spark-project @@ -88,12 +83,6 @@ hadoop2 - - - hadoopVersion - 2 - - org.spark-project diff --git a/pom.xml b/pom.xml index 3ea989a082..c6b9012dc6 100644 --- a/pom.xml +++ b/pom.xml @@ -499,11 +499,6 @@ hadoop1 - - - !hadoopVersion - - 1 @@ -521,12 +516,6 @@ hadoop2 - - - hadoopVersion - 2 - - 2 diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index da91c0f3ab..0667b71cc7 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -70,11 +70,6 @@ hadoop1 - - - !hadoopVersion - - hadoop1 @@ -115,12 +110,6 @@ hadoop2 - - - hadoopVersion - 2 - - hadoop2 diff --git a/repl/pom.xml b/repl/pom.xml index 2dc96beaf5..4a296fa630 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -72,11 +72,6 @@ hadoop1 - - - !hadoopVersion - - hadoop1 @@ -128,12 +123,6 @@ hadoop2 - - - hadoopVersion - 2 - - hadoop2 diff --git a/streaming/pom.xml b/streaming/pom.xml index 3dae815e1a..6ee7e59df3 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -83,11 +83,6 @@ hadoop1 - - - !hadoopVersion - - org.spark-project @@ -115,12 +110,6 @@ hadoop2 - - - hadoopVersion - 2 - - org.spark-project -- cgit v1.2.3 From 418e36caa8fcd9a70026ab762ec709732fdebd6b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Thu, 31 Jan 2013 17:18:33 -0600 Subject: Add more private declarations. --- core/src/main/scala/spark/MapOutputTracker.scala | 2 +- .../scala/spark/deploy/master/MasterWebUI.scala | 22 ++++------- .../main/scala/spark/scheduler/DAGScheduler.scala | 46 +++++++++++----------- .../scala/spark/scheduler/ShuffleMapTask.scala | 3 +- .../spark/scheduler/cluster/ClusterScheduler.scala | 2 +- .../spark/scheduler/cluster/TaskSetManager.scala | 19 ++++----- .../spark/scheduler/local/LocalScheduler.scala | 4 +- .../main/scala/spark/util/MetadataCleaner.scala | 10 ++--- 8 files changed, 49 insertions(+), 59 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index aaf433b324..4735207585 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -170,7 +170,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea } } - def cleanup(cleanupTime: Long) { + private def cleanup(cleanupTime: Long) { mapStatuses.clearOldValues(cleanupTime) cachedSerializedStatuses.clearOldValues(cleanupTime) } diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index a01774f511..529f72e9da 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -45,13 +45,9 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) => val future = master ? RequestMasterState val jobInfo = for (masterState <- future.mapTo[MasterState]) yield { - masterState.activeJobs.find(_.id == jobId) match { - case Some(job) => job - case _ => masterState.completedJobs.find(_.id == jobId) match { - case Some(job) => job - case _ => null - } - } + masterState.activeJobs.find(_.id == jobId).getOrElse({ + masterState.completedJobs.find(_.id == jobId).getOrElse(null) + }) } respondWithMediaType(MediaTypes.`application/json`) { ctx => ctx.complete(jobInfo.mapTo[JobInfo]) @@ -61,14 +57,10 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct val future = master ? RequestMasterState future.map { state => val masterState = state.asInstanceOf[MasterState] - - masterState.activeJobs.find(_.id == jobId) match { - case Some(job) => spark.deploy.master.html.job_details.render(job) - case _ => masterState.completedJobs.find(_.id == jobId) match { - case Some(job) => spark.deploy.master.html.job_details.render(job) - case _ => null - } - } + val job = masterState.activeJobs.find(_.id == jobId).getOrElse({ + masterState.completedJobs.find(_.id == jobId).getOrElse(null) + }) + spark.deploy.master.html.job_details.render(job) } } } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index b130be6a38..14f61f7e87 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -97,7 +97,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } }.start() - def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { + private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { if (!cacheLocs.contains(rdd.id)) { val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map { @@ -107,7 +107,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with cacheLocs(rdd.id) } - def clearCacheLocs() { + private def clearCacheLocs() { cacheLocs.clear() } @@ -116,7 +116,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * The priority value passed in will be used if the stage doesn't already exist with * a lower priority (we assume that priorities always increase across jobs for now). */ - def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = { + private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => @@ -131,11 +131,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * as a result stage for the final RDD used directly in an action. The stage will also be given * the provided priority. */ - def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = { - // Kind of ugly: need to register RDDs with the cache and map output tracker here - // since we can't do it in the RDD constructor because # of splits is unknown - logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") + private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = { if (shuffleDep != None) { + // Kind of ugly: need to register RDDs with the cache and map output tracker here + // since we can't do it in the RDD constructor because # of splits is unknown + logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) } val id = nextStageId.getAndIncrement() @@ -148,7 +148,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * Get or create the list of parent stages for a given RDD. The stages will be assigned the * provided priority if they haven't already been created with a lower priority. */ - def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = { + private def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] def visit(r: RDD[_]) { @@ -170,7 +170,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with parents.toList } - def getMissingParentStages(stage: Stage): List[Stage] = { + private def getMissingParentStages(stage: Stage): List[Stage] = { val missing = new HashSet[Stage] val visited = new HashSet[RDD[_]] def visit(rdd: RDD[_]) { @@ -241,7 +241,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * events and responds by launching tasks. This runs in a dedicated thread and receives events * via the eventQueue. */ - def run() { + private def run() { SparkEnv.set(env) while (true) { @@ -326,7 +326,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * We run the operation in a separate thread just in case it takes a bunch of time, so that we * don't block the DAGScheduler event loop or other concurrent jobs. */ - def runLocally(job: ActiveJob) { + private def runLocally(job: ActiveJob) { logInfo("Computing the requested partition locally") new Thread("Local computation of job " + job.runId) { override def run() { @@ -349,13 +349,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with }.start() } - def submitStage(stage: Stage) { + /** Submits stage, but first recursively submits any missing parents. */ + private def submitStage(stage: Stage) { logDebug("submitStage(" + stage + ")") if (!waiting(stage) && !running(stage) && !failed(stage)) { val missing = getMissingParentStages(stage).sortBy(_.id) logDebug("missing: " + missing) if (missing == Nil) { - logInfo("Submitting " + stage + " (" + stage.origin + "), which has no missing parents") + logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") submitMissingTasks(stage) running += stage } else { @@ -367,7 +368,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } } - def submitMissingTasks(stage: Stage) { + /** Called when stage's parents are available and we can now do its task. */ + private def submitMissingTasks(stage: Stage) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) @@ -388,7 +390,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } } if (tasks.size > 0) { - logInfo("Submitting " + tasks.size + " missing tasks from " + stage) + logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") myPending ++= tasks logDebug("New pending tasks: " + myPending) taskSched.submitTasks( @@ -407,7 +409,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * Responds to a task finishing. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. */ - def handleTaskCompletion(event: CompletionEvent) { + private def handleTaskCompletion(event: CompletionEvent) { val task = event.task val stage = idToStage(task.stageId) @@ -492,7 +494,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with waiting --= newlyRunnable running ++= newlyRunnable for (stage <- newlyRunnable.sortBy(_.id)) { - logInfo("Submitting " + stage + " (" + stage.origin + "), which is now runnable") + logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable") submitMissingTasks(stage) } } @@ -541,7 +543,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * Optionally the generation during which the failure was caught can be passed to avoid allowing * stray fetch failures from possibly retriggering the detection of a node as lost. */ - def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) { + private def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) { val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration) if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) { failedGeneration(execId) = currentGeneration @@ -567,7 +569,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. */ - def abortStage(failedStage: Stage, reason: String) { + private def abortStage(failedStage: Stage, reason: String) { val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) @@ -583,7 +585,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with /** * Return true if one of stage's ancestors is target. */ - def stageDependsOn(stage: Stage, target: Stage): Boolean = { + private def stageDependsOn(stage: Stage, target: Stage): Boolean = { if (stage == target) { return true } @@ -610,7 +612,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with visitedRdds.contains(target.rdd) } - def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = { + private def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = { // If the partition is cached, return the cache locations val cached = getCacheLocs(rdd)(partition) if (cached != Nil) { @@ -636,7 +638,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with return Nil } - def cleanup(cleanupTime: Long) { + private def cleanup(cleanupTime: Long) { var sizeBefore = idToStage.size idToStage.clearOldValues(cleanupTime) logInfo("idToStage " + sizeBefore + " --> " + idToStage.size) diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 83641a2a84..b701b67c89 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -127,7 +127,6 @@ private[spark] class ShuffleMapTask( val bucketId = dep.partitioner.getPartition(pair._1) buckets(bucketId) += pair } - val bucketIterators = buckets.map(_.iterator) val compressedSizes = new Array[Byte](numOutputSplits) @@ -135,7 +134,7 @@ private[spark] class ShuffleMapTask( for (i <- 0 until numOutputSplits) { val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i // Get a Scala iterator from Java map - val iter: Iterator[(Any, Any)] = bucketIterators(i) + val iter: Iterator[(Any, Any)] = buckets(i).iterator val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false) compressedSizes(i) = MapOutputTracker.compressSize(size) } diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 0b4177805b..1e4fbdb874 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -86,7 +86,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - def submitTasks(taskSet: TaskSet) { + override def submitTasks(taskSet: TaskSet) { val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 26201ad0dd..3dabdd76b1 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -17,10 +17,7 @@ import java.nio.ByteBuffer /** * Schedules the tasks within a single TaskSet in the ClusterScheduler. */ -private[spark] class TaskSetManager( - sched: ClusterScheduler, - val taskSet: TaskSet) - extends Logging { +private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Logging { // Maximum time to wait to run a task in a preferred location (in ms) val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong @@ -100,7 +97,7 @@ private[spark] class TaskSetManager( } // Add a task to all the pending-task lists that it should be on. - def addPendingTask(index: Int) { + private def addPendingTask(index: Int) { val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive if (locations.size == 0) { pendingTasksWithNoPrefs += index @@ -115,7 +112,7 @@ private[spark] class TaskSetManager( // Return the pending tasks list for a given host, or an empty list if // there is no map entry for that host - def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { + private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { pendingTasksForHost.getOrElse(host, ArrayBuffer()) } @@ -123,7 +120,7 @@ private[spark] class TaskSetManager( // Return None if the list is empty. // This method also cleans up any tasks in the list that have already // been launched, since we want that to happen lazily. - def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { + private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { while (!list.isEmpty) { val index = list.last list.trimEnd(1) @@ -137,7 +134,7 @@ private[spark] class TaskSetManager( // Return a speculative task for a given host if any are available. The task should not have an // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the // task must have a preference for this host (or no preferred locations at all). - def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { + private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { val hostsAlive = sched.hostsAlive speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set val localTask = speculatableTasks.find { @@ -162,7 +159,7 @@ private[spark] class TaskSetManager( // Dequeue a pending task for a given node and return its index. // If localOnly is set to false, allow non-local tasks as well. - def findTask(host: String, localOnly: Boolean): Option[Int] = { + private def findTask(host: String, localOnly: Boolean): Option[Int] = { val localTask = findTaskFromList(getPendingTasksForHost(host)) if (localTask != None) { return localTask @@ -184,7 +181,7 @@ private[spark] class TaskSetManager( // Does a host count as a preferred location for a task? This is true if // either the task has preferred locations and this host is one, or it has // no preferred locations (in which we still count the launch as preferred). - def isPreferredLocation(task: Task[_], host: String): Boolean = { + private def isPreferredLocation(task: Task[_], host: String): Boolean = { val locs = task.preferredLocations return (locs.contains(host) || locs.isEmpty) } @@ -335,7 +332,7 @@ private[spark] class TaskSetManager( if (numFailures(index) > MAX_TASK_FAILURES) { logError("Task %s:%d failed more than %d times; aborting job".format( taskSet.id, index, MAX_TASK_FAILURES)) - abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES)) + abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) } } } else { diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 9ff7c02097..482d1cc853 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon } def runTask(task: Task[_], idInJob: Int, attemptId: Int) { - logInfo("Running task " + idInJob) + logInfo("Running " + task) // Set the Spark execution environment for the worker thread SparkEnv.set(env) try { @@ -80,7 +80,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon val resultToReturn = ser.deserialize[Any](ser.serialize(result)) val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( ser.serialize(Accumulators.values)) - logInfo("Finished task " + idInJob) + logInfo("Finished " + task) // If the threadpool has not already been shutdown, notify DAGScheduler if (!Thread.currentThread().isInterrupted) diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index eaff7ae581..a342d378ff 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -9,12 +9,12 @@ import spark.Logging * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) */ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { - val delaySeconds = MetadataCleaner.getDelaySeconds - val periodSeconds = math.max(10, delaySeconds / 10) - val timer = new Timer(name + " cleanup timer", true) + private val delaySeconds = MetadataCleaner.getDelaySeconds + private val periodSeconds = math.max(10, delaySeconds / 10) + private val timer = new Timer(name + " cleanup timer", true) - val task = new TimerTask { - def run() { + private val task = new TimerTask { + override def run() { try { cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) logInfo("Ran metadata cleaner for " + name) -- cgit v1.2.3 From 782187c21047ee31728bdb173a2b7ee708cef77b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Thu, 31 Jan 2013 18:27:25 -0600 Subject: Once we find a split with no block, we don't have to look for more. --- .../main/scala/spark/scheduler/DAGScheduler.scala | 23 +++++++++++----------- 1 file changed, 11 insertions(+), 12 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index b130be6a38..b62b25f688 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -177,18 +177,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with if (!visited(rdd)) { visited += rdd val locs = getCacheLocs(rdd) - for (p <- 0 until rdd.splits.size) { - if (locs(p) == Nil) { - for (dep <- rdd.dependencies) { - dep match { - case shufDep: ShuffleDependency[_,_] => - val mapStage = getShuffleMapStage(shufDep, stage.priority) - if (!mapStage.isAvailable) { - missing += mapStage - } - case narrowDep: NarrowDependency[_] => - visit(narrowDep.rdd) - } + val atLeastOneMissing = (0 until rdd.splits.size).exists(locs(_) == Nil) + if (atLeastOneMissing) { + for (dep <- rdd.dependencies) { + dep match { + case shufDep: ShuffleDependency[_,_] => + val mapStage = getShuffleMapStage(shufDep, stage.priority) + if (!mapStage.isAvailable) { + missing += mapStage + } + case narrowDep: NarrowDependency[_] => + visit(narrowDep.rdd) } } } -- cgit v1.2.3 From 5b0fc265c2f2ce461d61904c2a4e6e47b24d2bbe Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 31 Jan 2013 17:48:39 -0800 Subject: Changed PartitionPruningRDD's split to make sure it returns the correct split index. --- core/src/main/scala/spark/Dependency.scala | 8 ++++++++ core/src/main/scala/spark/rdd/PartitionPruningRDD.scala | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index 647aee6eb5..827eac850a 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -72,6 +72,14 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo @transient val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index)) + .zipWithIndex + .map { case(split, idx) => new PruneDependency.PartitionPruningRDDSplit(idx, split) : Split } override def getParents(partitionId: Int) = List(partitions(partitionId).index) } + +object PruneDependency { + class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split { + override val index = idx + } +} diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index b8482338c6..0989b149e1 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -2,6 +2,7 @@ package spark.rdd import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} + /** * A RDD used to prune RDD partitions/splits so we can avoid launching tasks on * all partitions. An example use case: If we know the RDD is partitioned by range, @@ -15,7 +16,8 @@ class PartitionPruningRDD[T: ClassManifest]( @transient partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { - override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) + override def compute(split: Split, context: TaskContext) = firstParent[T].iterator( + split.asInstanceOf[PruneDependency.PartitionPruningRDDSplit].parentSplit, context) override protected def getSplits = getDependencies.head.asInstanceOf[PruneDependency[T]].partitions -- cgit v1.2.3 From 6289d9654e32fc92418d41cc6e32fee30f85c833 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 31 Jan 2013 17:49:36 -0800 Subject: Removed the TODO comment from PartitionPruningRDD. --- core/src/main/scala/spark/rdd/PartitionPruningRDD.scala | 2 -- 1 file changed, 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 0989b149e1..3756870fac 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -8,8 +8,6 @@ import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} * all partitions. An example use case: If we know the RDD is partitioned by range, * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. - * - * TODO: This currently doesn't give partition IDs properly! */ class PartitionPruningRDD[T: ClassManifest]( @transient prev: RDD[T], -- cgit v1.2.3 From 3446d5c8d6b385106ac85e46320d92faa8efb4e6 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 31 Jan 2013 18:02:28 -0800 Subject: SPARK-673: Capture and re-throw Python exceptions This patch alters the Python <-> executor protocol to pass on exception data when they occur in user Python code. --- .../main/scala/spark/api/python/PythonRDD.scala | 40 ++++++++++++++-------- python/pyspark/worker.py | 10 ++++-- 2 files changed, 34 insertions(+), 16 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index f43a152ca7..6b9ef62529 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -103,21 +103,30 @@ private[spark] class PythonRDD[T: ClassManifest]( private def read(): Array[Byte] = { try { - val length = stream.readInt() - if (length != -1) { - val obj = new Array[Byte](length) - stream.readFully(obj) - obj - } else { - // We've finished the data section of the output, but we can still read some - // accumulator updates; let's do that, breaking when we get EOFException - while (true) { - val len2 = stream.readInt() - val update = new Array[Byte](len2) - stream.readFully(update) - accumulator += Collections.singletonList(update) + stream.readInt() match { + case length if length > 0 => { + val obj = new Array[Byte](length) + stream.readFully(obj) + obj } - new Array[Byte](0) + case -2 => { + // Signals that an exception has been thrown in python + val exLength = stream.readInt() + val obj = new Array[Byte](exLength) + stream.readFully(obj) + throw new PythonException(new String(obj)) + } + case -1 => { + // We've finished the data section of the output, but we can still read some + // accumulator updates; let's do that, breaking when we get EOFException + while (true) { + val len2 = stream.readInt() + val update = new Array[Byte](len2) + stream.readFully(update) + accumulator += Collections.singletonList(update) + } + new Array[Byte](0) + } } } catch { case eof: EOFException => { @@ -140,6 +149,9 @@ private[spark] class PythonRDD[T: ClassManifest]( val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } +/** Thrown for exceptions in user Python code. */ +private class PythonException(msg: String) extends Exception(msg) + /** * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. * This is used by PySpark's shuffle operations. diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d33d6dd15f..9622e0cfe4 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -2,6 +2,7 @@ Worker that receives input from Piped RDD. """ import sys +import traceback from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. @@ -40,8 +41,13 @@ def main(): else: dumps = dump_pickle iterator = read_from_pickle_file(sys.stdin) - for obj in func(split_index, iterator): - write_with_length(dumps(obj), old_stdout) + try: + for obj in func(split_index, iterator): + write_with_length(dumps(obj), old_stdout) + except Exception as e: + write_int(-2, old_stdout) + write_with_length(traceback.format_exc(), old_stdout) + sys.exit(-1) # Mark the beginning of the accumulators section of the output write_int(-1, old_stdout) for aid, accum in _accumulatorRegistry.items(): -- cgit v1.2.3 From c33f0ef41a1865de2bae01b52b860650d3734da4 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 31 Jan 2013 21:50:02 -0800 Subject: Some style cleanup --- core/src/main/scala/spark/api/python/PythonRDD.scala | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 6b9ef62529..23e3149248 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -104,19 +104,17 @@ private[spark] class PythonRDD[T: ClassManifest]( private def read(): Array[Byte] = { try { stream.readInt() match { - case length if length > 0 => { + case length if length > 0 => val obj = new Array[Byte](length) stream.readFully(obj) obj - } - case -2 => { + case -2 => // Signals that an exception has been thrown in python val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) throw new PythonException(new String(obj)) - } - case -1 => { + case -1 => // We've finished the data section of the output, but we can still read some // accumulator updates; let's do that, breaking when we get EOFException while (true) { @@ -124,9 +122,8 @@ private[spark] class PythonRDD[T: ClassManifest]( val update = new Array[Byte](len2) stream.readFully(update) accumulator += Collections.singletonList(update) + new Array[Byte](0) } - new Array[Byte](0) - } } } catch { case eof: EOFException => { -- cgit v1.2.3 From 39ab83e9577a5449fb0d6ef944dffc0d7cd00b4a Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 31 Jan 2013 21:52:52 -0800 Subject: Small fix from last commit --- core/src/main/scala/spark/api/python/PythonRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 23e3149248..39758e94f4 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -122,8 +122,8 @@ private[spark] class PythonRDD[T: ClassManifest]( val update = new Array[Byte](len2) stream.readFully(update) accumulator += Collections.singletonList(update) - new Array[Byte](0) } + new Array[Byte](0) } } catch { case eof: EOFException => { -- cgit v1.2.3 From f9af9cee6fed9c6af896fb92556ad4f48c7f8e64 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 1 Feb 2013 00:02:46 -0800 Subject: Moved PruneDependency into PartitionPruningRDD.scala. --- core/src/main/scala/spark/Dependency.scala | 22 ------------------ .../main/scala/spark/rdd/PartitionPruningRDD.scala | 26 ++++++++++++++++++---- 2 files changed, 22 insertions(+), 26 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index 827eac850a..5eea907322 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -61,25 +61,3 @@ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) } } } - - -/** - * Represents a dependency between the PartitionPruningRDD and its parent. In this - * case, the child RDD contains a subset of partitions of the parents'. - */ -class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) - extends NarrowDependency[T](rdd) { - - @transient - val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index)) - .zipWithIndex - .map { case(split, idx) => new PruneDependency.PartitionPruningRDDSplit(idx, split) : Split } - - override def getParents(partitionId: Int) = List(partitions(partitionId).index) -} - -object PruneDependency { - class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split { - override val index = idx - } -} diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 3756870fac..a50ce75171 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -1,6 +1,26 @@ package spark.rdd -import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} +import spark.{NarrowDependency, RDD, SparkEnv, Split, TaskContext} + + +class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split { + override val index = idx +} + + +/** + * Represents a dependency between the PartitionPruningRDD and its parent. In this + * case, the child RDD contains a subset of partitions of the parents'. + */ +class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) + extends NarrowDependency[T](rdd) { + + @transient + val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index)) + .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDSplit(idx, split) : Split } + + override def getParents(partitionId: Int) = List(partitions(partitionId).index) +} /** @@ -15,10 +35,8 @@ class PartitionPruningRDD[T: ClassManifest]( extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { override def compute(split: Split, context: TaskContext) = firstParent[T].iterator( - split.asInstanceOf[PruneDependency.PartitionPruningRDDSplit].parentSplit, context) + split.asInstanceOf[PartitionPruningRDDSplit].parentSplit, context) override protected def getSplits = getDependencies.head.asInstanceOf[PruneDependency[T]].partitions - - override val partitioner = firstParent[T].partitioner } -- cgit v1.2.3 From f127f2ae76692b189d86b5a47293579d5657c6d5 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 1 Feb 2013 00:20:49 -0800 Subject: fixup merge (master -> driver renaming) --- core/src/main/scala/spark/storage/BlockManagerMaster.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 99324445ca..0372cb080a 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -118,7 +118,7 @@ private[spark] class BlockManagerMaster( } def getStorageStatus: Array[StorageStatus] = { - askMasterWithRetry[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray + askDriverWithReply[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray } /** Stop the driver actor, called only on the Spark driver node */ -- cgit v1.2.3 From 8a0a5ed53353ad6aa5656eb729d55ca7af2ab096 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 1 Feb 2013 00:23:38 -0800 Subject: track total partitions, in addition to cached partitions; use scala string formatting --- core/src/main/scala/spark/storage/StorageUtils.scala | 10 ++++------ core/src/main/twirl/spark/storage/rdd.scala.html | 6 +++++- core/src/main/twirl/spark/storage/rdd_table.scala.html | 6 ++++-- 3 files changed, 13 insertions(+), 9 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index ce7c067eea..5367b74bb6 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -22,12 +22,11 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, } case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, - numPartitions: Int, memSize: Long, diskSize: Long) { + numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) { override def toString = { import Utils.memoryBytesToString - import java.lang.{Integer => JInt} - String.format("RDD \"%s\" (%d) Storage: %s; Partitions: %d; MemorySize: %s; DiskSize: %s", name, id.asInstanceOf[JInt], - storageLevel.toString, numPartitions.asInstanceOf[JInt], memoryBytesToString(memSize), memoryBytesToString(diskSize)) + "RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id, + storageLevel.toString, numCachedPartitions, numPartitions, memoryBytesToString(memSize), memoryBytesToString(diskSize)) } } @@ -65,9 +64,8 @@ object StorageUtils { val rdd = sc.persistentRdds(rddId) val rddName = Option(rdd.name).getOrElse(rddKey) val rddStorageLevel = rdd.getStorageLevel - //TODO get total number of partitions in rdd - RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize) + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.splits.size, memSize, diskSize) }.toArray } diff --git a/core/src/main/twirl/spark/storage/rdd.scala.html b/core/src/main/twirl/spark/storage/rdd.scala.html index ac7f8c981f..d85addeb17 100644 --- a/core/src/main/twirl/spark/storage/rdd.scala.html +++ b/core/src/main/twirl/spark/storage/rdd.scala.html @@ -11,7 +11,11 @@ Storage Level: @(rddInfo.storageLevel.description)
  • - Partitions: + Cached Partitions: + @(rddInfo.numCachedPartitions) +
  • +
  • + Total Partitions: @(rddInfo.numPartitions)
  • diff --git a/core/src/main/twirl/spark/storage/rdd_table.scala.html b/core/src/main/twirl/spark/storage/rdd_table.scala.html index af801cf229..a51e64aed0 100644 --- a/core/src/main/twirl/spark/storage/rdd_table.scala.html +++ b/core/src/main/twirl/spark/storage/rdd_table.scala.html @@ -6,7 +6,8 @@ RDD Name Storage Level - Partitions + Cached Partitions + Fraction Partitions Cached Size in Memory Size on Disk @@ -21,7 +22,8 @@ @(rdd.storageLevel.description) - @rdd.numPartitions + @rdd.numCachedPartitions + @(rdd.numCachedPartitions / rdd.numPartitions.toDouble) @{Utils.memoryBytesToString(rdd.memSize)} @{Utils.memoryBytesToString(rdd.diskSize)} -- cgit v1.2.3 From 59c57e48dfb362923610785b230d5b3b56c620c3 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 1 Feb 2013 10:34:02 -0600 Subject: Stop BlockManagers metadataCleaner. --- core/src/main/scala/spark/storage/BlockManager.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index c61fd75c2b..9893e9625d 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -950,6 +950,7 @@ class BlockManager( blockInfo.clear() memoryStore.clear() diskStore.clear() + metadataCleaner.cancel() logInfo("BlockManager stopped") } } -- cgit v1.2.3 From c6190067ae40cf457b7f2e58619904b6fd2b1cb6 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 1 Feb 2013 09:55:25 -0800 Subject: remove unneeded (and unused) filter on block info --- core/src/main/scala/spark/storage/StorageUtils.scala | 2 -- 1 file changed, 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index 5367b74bb6..5f72b67b2c 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -43,8 +43,6 @@ object StorageUtils { /* Given a list of BlockStatus objets, returns information for each RDD */ def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus], sc: SparkContext) : Array[RDDInfo] = { - // Find all RDD Blocks (ignore broadcast variables) - val rddBlocks = infos.filterKeys(_.startsWith("rdd")) // Group by rddId, ignore the partition name val groupedRddBlocks = infos.groupBy { case(k, v) => -- cgit v1.2.3 From 9970926ede0d5a719b8f22e97977804d3c811e97 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 1 Feb 2013 14:07:34 -0800 Subject: formatting --- core/src/main/scala/spark/RDD.scala | 2 +- core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 210404d540..010e61dfdc 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -385,7 +385,7 @@ abstract class RDD[T: ClassManifest]( val reducePartition: Iterator[T] => Option[T] = iter => { if (iter.hasNext) { Some(iter.reduceLeft(cleanF)) - }else { + } else { None } } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 83641a2a84..20f2c9e489 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -32,7 +32,7 @@ private[spark] object ShuffleMapTask { return old } else { val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance + val ser = SparkEnv.get.closureSerializer.newInstance() val objOut = ser.serializeStream(new GZIPOutputStream(out)) objOut.writeObject(rdd) objOut.writeObject(dep) @@ -48,7 +48,7 @@ private[spark] object ShuffleMapTask { synchronized { val loader = Thread.currentThread.getContextClassLoader val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance + val ser = SparkEnv.get.closureSerializer.newInstance() val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] -- cgit v1.2.3 From 8b3041c7233011c4a96fab045a86df91eae7b6f3 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 1 Feb 2013 15:38:42 -0800 Subject: Reduced the memory usage of reduce and similar operations These operations used to wait for all the results to be available in an array on the driver program before merging them. They now merge values incrementally as they arrive. --- core/src/main/scala/spark/PairRDDFunctions.scala | 4 +- core/src/main/scala/spark/RDD.scala | 41 ++++++++++------- core/src/main/scala/spark/SparkContext.scala | 53 +++++++++++++++++++--- core/src/main/scala/spark/Utils.scala | 8 ++++ .../spark/partial/ApproximateActionListener.scala | 4 +- .../main/scala/spark/scheduler/DAGScheduler.scala | 15 +++--- .../src/main/scala/spark/scheduler/JobResult.scala | 2 +- .../src/main/scala/spark/scheduler/JobWaiter.scala | 14 +++--- core/src/test/scala/spark/RDDSuite.scala | 12 +++-- 9 files changed, 107 insertions(+), 46 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 231e23a7de..cc3cca2571 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -465,7 +465,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( val res = self.context.runJob(self, process _, Array(index), false) res(0) case None => - self.filter(_._1 == key).map(_._2).collect + self.filter(_._1 == key).map(_._2).collect() } } @@ -590,7 +590,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( var count = 0 while(iter.hasNext) { - val record = iter.next + val record = iter.next() count += 1 writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) } diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 010e61dfdc..9d6ea782bd 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -389,16 +389,18 @@ abstract class RDD[T: ClassManifest]( None } } - val options = sc.runJob(this, reducePartition) - val results = new ArrayBuffer[T] - for (opt <- options; elem <- opt) { - results += elem - } - if (results.size == 0) { - throw new UnsupportedOperationException("empty collection") - } else { - return results.reduceLeft(cleanF) + var jobResult: Option[T] = None + val mergeResult = (index: Int, taskResult: Option[T]) => { + if (taskResult != None) { + jobResult = jobResult match { + case Some(value) => Some(f(value, taskResult.get)) + case None => taskResult + } + } } + sc.runJob(this, reducePartition, mergeResult) + // Get the final result out of our Option, or throw an exception if the RDD was empty + jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")) } /** @@ -408,9 +410,13 @@ abstract class RDD[T: ClassManifest]( * modify t2. */ def fold(zeroValue: T)(op: (T, T) => T): T = { + // Clone the zero value since we will also be serializing it as part of tasks + var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) val cleanOp = sc.clean(op) - val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp)) - return results.fold(zeroValue)(cleanOp) + val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp) + val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult) + sc.runJob(this, foldPartition, mergeResult) + jobResult } /** @@ -422,11 +428,14 @@ abstract class RDD[T: ClassManifest]( * allocation. */ def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { + // Clone the zero value since we will also be serializing it as part of tasks + var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) val cleanSeqOp = sc.clean(seqOp) val cleanCombOp = sc.clean(combOp) - val results = sc.runJob(this, - (iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)) - return results.fold(zeroValue)(cleanCombOp) + val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) + val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult) + sc.runJob(this, aggregatePartition, mergeResult) + jobResult } /** @@ -437,7 +446,7 @@ abstract class RDD[T: ClassManifest]( var result = 0L while (iter.hasNext) { result += 1L - iter.next + iter.next() } result }).sum @@ -452,7 +461,7 @@ abstract class RDD[T: ClassManifest]( var result = 0L while (iter.hasNext) { result += 1L - iter.next + iter.next() } result } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index b0d4b58240..ddbf8f95d9 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -543,26 +543,42 @@ class SparkContext( } /** - * Run a function on a given set of partitions in an RDD and return the results. This is the main - * entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies - * whether the scheduler can run the computation on the driver rather than shipping it out to the - * cluster, for short actions like first(). + * Run a function on a given set of partitions in an RDD and pass the results to the given + * handler function. This is the main entry point for all actions in Spark. The allowLocal + * flag specifies whether the scheduler can run the computation on the driver rather than + * shipping it out to the cluster, for short actions like first(). */ def runJob[T, U: ClassManifest]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], - allowLocal: Boolean - ): Array[U] = { + allowLocal: Boolean, + resultHandler: (Int, U) => Unit) { val callSite = Utils.getSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal) + val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() result } + /** + * Run a function on a given set of partitions in an RDD and return the results as an array. The + * allowLocal flag specifies whether the scheduler can run the computation on the driver rather + * than shipping it out to the cluster, for short actions like first(). + */ + def runJob[T, U: ClassManifest]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + allowLocal: Boolean + ): Array[U] = { + val results = new Array[U](partitions.size) + runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res) + results + } + /** * Run a job on a given set of partitions of an RDD, but take a function of type * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. @@ -590,6 +606,29 @@ class SparkContext( runJob(rdd, func, 0 until rdd.splits.size, false) } + /** + * Run a job on all partitions in an RDD and pass the results to a handler function. + */ + def runJob[T, U: ClassManifest]( + rdd: RDD[T], + processPartition: (TaskContext, Iterator[T]) => U, + resultHandler: (Int, U) => Unit) + { + runJob[T, U](rdd, processPartition, 0 until rdd.splits.size, false, resultHandler) + } + + /** + * Run a job on all partitions in an RDD and pass the results to a handler function. + */ + def runJob[T, U: ClassManifest]( + rdd: RDD[T], + processPartition: Iterator[T] => U, + resultHandler: (Int, U) => Unit) + { + val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter) + runJob[T, U](rdd, processFunc, 0 until rdd.splits.size, false, resultHandler) + } + /** * Run a job that can return approximate results. */ diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 1e58d01273..28d643abca 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -12,6 +12,7 @@ import scala.io.Source import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder import scala.Some +import spark.serializer.SerializerInstance /** * Various utility methods used by Spark. @@ -446,4 +447,11 @@ private object Utils extends Logging { socket.close() portBound } + + /** + * Clone an object using a Spark serializer. + */ + def clone[T](value: T, serializer: SerializerInstance): T = { + serializer.deserialize[T](serializer.serialize(value)) + } } diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala index 42f46e06ed..24b4909380 100644 --- a/core/src/main/scala/spark/partial/ApproximateActionListener.scala +++ b/core/src/main/scala/spark/partial/ApproximateActionListener.scala @@ -32,7 +32,7 @@ private[spark] class ApproximateActionListener[T, U, R]( if (finishedTasks == totalTasks) { // If we had already returned a PartialResult, set its final value resultObject.foreach(r => r.setFinalValue(evaluator.currentResult())) - // Notify any waiting thread that may have called getResult + // Notify any waiting thread that may have called awaitResult this.notifyAll() } } @@ -49,7 +49,7 @@ private[spark] class ApproximateActionListener[T, U, R]( * Waits for up to timeout milliseconds since the listener was created and then returns a * PartialResult with the result so far. This may be complete if the whole job is done. */ - def getResult(): PartialResult[R] = synchronized { + def awaitResult(): PartialResult[R] = synchronized { val finishTime = startTime + timeout while (true) { val time = System.currentTimeMillis() diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 14f61f7e87..908a22b2df 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -203,18 +203,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: String, - allowLocal: Boolean) - : Array[U] = + allowLocal: Boolean, + resultHandler: (Int, U) => Unit) { if (partitions.size == 0) { - return new Array[U](0) + return } - val waiter = new JobWaiter(partitions.size) + val waiter = new JobWaiter(partitions.size, resultHandler) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)) - waiter.getResult() match { - case JobSucceeded(results: Seq[_]) => - return results.asInstanceOf[Seq[U]].toArray + waiter.awaitResult() match { + case JobSucceeded => {} case JobFailed(exception: Exception) => logInfo("Failed to run " + callSite) throw exception @@ -233,7 +232,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.splits.size).toArray eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener)) - return listener.getResult() // Will throw an exception if the job fails + return listener.awaitResult() // Will throw an exception if the job fails } /** diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala index c4a74e526f..654131ee84 100644 --- a/core/src/main/scala/spark/scheduler/JobResult.scala +++ b/core/src/main/scala/spark/scheduler/JobResult.scala @@ -5,5 +5,5 @@ package spark.scheduler */ private[spark] sealed trait JobResult -private[spark] case class JobSucceeded(results: Seq[_]) extends JobResult +private[spark] case object JobSucceeded extends JobResult private[spark] case class JobFailed(exception: Exception) extends JobResult diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala index b3d4feebe5..3cc6a86345 100644 --- a/core/src/main/scala/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/spark/scheduler/JobWaiter.scala @@ -3,10 +3,12 @@ package spark.scheduler import scala.collection.mutable.ArrayBuffer /** - * An object that waits for a DAGScheduler job to complete. + * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their + * results to the given handler function. */ -private[spark] class JobWaiter(totalTasks: Int) extends JobListener { - private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null) +private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) + extends JobListener { + private var finishedTasks = 0 private var jobFinished = false // Is the job as a whole finished (succeeded or failed)? @@ -17,11 +19,11 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener { if (jobFinished) { throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") } - taskResults(index) = result + resultHandler(index, result.asInstanceOf[T]) finishedTasks += 1 if (finishedTasks == totalTasks) { jobFinished = true - jobResult = JobSucceeded(taskResults) + jobResult = JobSucceeded this.notifyAll() } } @@ -38,7 +40,7 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener { } } - def getResult(): JobResult = synchronized { + def awaitResult(): JobResult = synchronized { while (!jobFinished) { this.wait() } diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index ed03e65153..95d2e62730 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -12,9 +12,9 @@ class RDDSuite extends FunSuite with LocalSparkContext { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(nums.collect().toList === List(1, 2, 3, 4)) val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) - assert(dups.distinct.count === 4) - assert(dups.distinct().collect === dups.distinct.collect) - assert(dups.distinct(2).collect === dups.distinct.collect) + assert(dups.distinct().count === 4) + assert(dups.distinct().collect === dups.distinct().collect) + assert(dups.distinct(2).collect === dups.distinct().collect) assert(nums.reduce(_ + _) === 10) assert(nums.fold(0)(_ + _) === 10) assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4")) @@ -31,6 +31,10 @@ class RDDSuite extends FunSuite with LocalSparkContext { case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) } assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7))) + + intercept[UnsupportedOperationException] { + nums.filter(_ > 5).reduce(_ + _) + } } test("SparkContext.union") { @@ -164,7 +168,7 @@ class RDDSuite extends FunSuite with LocalSparkContext { // Note that split number starts from 0, so > 8 means only 10th partition left. val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) assert(prunedRdd.splits.size === 1) - val prunedData = prunedRdd.collect + val prunedData = prunedRdd.collect() assert(prunedData.size === 1) assert(prunedData(0) === 10) } -- cgit v1.2.3 From 12c1eb47568060efac57d6df7df7e5704a8d3fab Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 1 Feb 2013 21:21:44 -0600 Subject: Reduce the amount of duplicate logging Akka does to stdout. Given we have Akka logging go through SLF4j to log4j, we don't need all the extra noise of Akka's stdout logger that is supposedly only used during Akka init time but seems to continue logging lots of noisy network events that we either don't care about or are in the log4j logs anyway. See: http://doc.akka.io/docs/akka/2.0/general/configuration.html # Log level for the very basic logger activated during AkkaApplication startup # Options: ERROR, WARNING, INFO, DEBUG # stdout-loglevel = "WARNING" --- core/src/main/scala/spark/util/AkkaUtils.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core') diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index e0fdeffbc4..e43fbd6b1c 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -30,6 +30,7 @@ private[spark] object AkkaUtils { val akkaConf = ConfigFactory.parseString(""" akka.daemonic = on akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] + akka.stdout-loglevel = "ERROR" akka.actor.provider = "akka.remote.RemoteActorRefProvider" akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" akka.remote.log-remote-lifecycle-events = on -- cgit v1.2.3 From ae26911ec0d768dcdae8b7d706ca4544e36535e6 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 1 Feb 2013 21:07:24 -0800 Subject: Add back test for distinct without parens --- core/src/test/scala/spark/RDDSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 95d2e62730..89a3687386 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -12,7 +12,8 @@ class RDDSuite extends FunSuite with LocalSparkContext { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(nums.collect().toList === List(1, 2, 3, 4)) val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) - assert(dups.distinct().count === 4) + assert(dups.distinct().count() === 4) + assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses? assert(dups.distinct().collect === dups.distinct().collect) assert(dups.distinct(2).collect === dups.distinct().collect) assert(nums.reduce(_ + _) === 10) -- cgit v1.2.3 From 1fd5ee323d127499bb3f173d4142c37532ec29b2 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Fri, 1 Feb 2013 22:33:38 -0800 Subject: Code review changes: add sc.stop; style of multiline comments; parens on procedure calls. --- .../scala/spark/scheduler/DAGSchedulerSuite.scala | 69 +++++++++++++++------- 1 file changed, 47 insertions(+), 22 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index c31e2e7064..adce1f38bb 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -31,7 +31,7 @@ import spark.TaskEndReason import spark.{FetchFailed, Success} /** - * Tests for DAGScheduler. These tests directly call the event processing functinos in DAGScheduler + * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler * rather than spawning an event loop thread as happens in the real code. They use EasyMock * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead @@ -56,29 +56,34 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar var schedulerThread: Thread = null var schedulerException: Throwable = null - /** Set of EasyMock argument matchers that match a TaskSet for a given RDD. + /** + * Set of EasyMock argument matchers that match a TaskSet for a given RDD. * We cache these so we do not create duplicate matchers for the same RDD. * This allows us to easily setup a sequence of expectations for task sets for * that RDD. */ val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher] - /** Set of cache locations to return from our mock BlockManagerMaster. + /** + * Set of cache locations to return from our mock BlockManagerMaster. * Keys are (rdd ID, partition ID). Anything not present will return an empty * list of cache locations silently. */ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] - /** JobWaiter for the last JobSubmitted event we pushed. To keep tests (most of which + /** + * JobWaiter for the last JobSubmitted event we pushed. To keep tests (most of which * will only submit one job) from needing to explicitly track it. */ var lastJobWaiter: JobWaiter = null - /** Tell EasyMockSugar what mock objects we want to be configured by expecting {...} + /** + * Tell EasyMockSugar what mock objects we want to be configured by expecting {...} * and whenExecuting {...} */ implicit val mocks = MockObjects(taskScheduler, blockManagerMaster) - /** Utility function to reset mocks and set expectations on them. EasyMock wants mock objects + /** + * Utility function to reset mocks and set expectations on them. EasyMock wants mock objects * to be reset after each time their expectations are set, and we tend to check mock object * calls over a single call to DAGScheduler. * @@ -115,17 +120,21 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar whenExecuting { scheduler.stop() } + sc.stop() System.clearProperty("spark.master.port") } def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) - /** Type of RDD we use for testing. Note that we should never call the real RDD compute methods. - * This is a pair RDD type so it can always be used in ShuffleDependencies. */ + /** + * Type of RDD we use for testing. Note that we should never call the real RDD compute methods. + * This is a pair RDD type so it can always be used in ShuffleDependencies. + */ type MyRDD = RDD[(Int, Int)] - /** Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and + /** + * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and * preferredLocations (if any) that are passed to them. They are deliberately not executable * so we can test that DAGScheduler does not try to execute RDDs locally. */ @@ -150,7 +159,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } } - /** EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task + /** + * EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task * is from a particular RDD. */ def taskSetForRdd(rdd: MyRDD): TaskSet = { @@ -172,7 +182,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar return null } - /** Setup an EasyMock expectation to repsond to blockManagerMaster.getLocations() called from + /** + * Setup an EasyMock expectation to repsond to blockManagerMaster.getLocations() called from * cacheLocations. */ def expectGetLocations(): Unit = { @@ -197,7 +208,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar }).anyTimes() } - /** Process the supplied event as if it were the top of the DAGScheduler event queue, expecting + /** + * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting * the scheduler not to exit. * * After processing the event, submit waiting stages as is done on most iterations of the @@ -208,7 +220,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar scheduler.submitWaitingStages() } - /** Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be + /** + * Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be * called from a resetExpecting { ... } block. * * Returns a easymock Capture that will contain the task set after the stage is submitted. @@ -220,7 +233,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar return taskSetCapture } - /** Expect the supplied code snippet to submit a stage for the specified RDD. + /** + * Expect the supplied code snippet to submit a stage for the specified RDD. * Return the resulting TaskSet. First marks all the tasks are belonging to the * current MapOutputTracker generation. */ @@ -239,7 +253,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar return taskSet } - /** Send the given CompletionEvent messages for the tasks in the TaskSet. */ + /** + * Send the given CompletionEvent messages for the tasks in the TaskSet. + */ def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) { assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { @@ -249,7 +265,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } } - /** Assert that the supplied TaskSet has exactly the given preferredLocations. */ + /** + * Assert that the supplied TaskSet has exactly the given preferredLocations. + */ def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) { assert(locations.size === taskSet.tasks.size) for ((expectLocs, taskLocs) <- @@ -258,7 +276,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } } - /** When we submit dummy Jobs, this is the compute function we supply. Except in a local test + /** + * When we submit dummy Jobs, this is the compute function we supply. Except in a local test * below, we do not expect this function to ever be executed; instead, we will return results * directly through CompletionEvents. */ @@ -266,8 +285,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar it.next._1.asInstanceOf[Int] - /** Start a job to compute the given RDD. Returns the JobWaiter that will - * collect the result of the job via callbacks from DAGScheduler. */ + /** + * Start a job to compute the given RDD. Returns the JobWaiter that will + * collect the result of the job via callbacks from DAGScheduler. + */ def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): JobWaiter = { val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int]( rdd, @@ -281,7 +302,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar return waiter } - /** Assert that a job we started has failed. */ + /** + * Assert that a job we started has failed. + */ def expectJobException(waiter: JobWaiter = lastJobWaiter) { waiter.getResult match { case JobSucceeded(_) => fail() @@ -289,7 +312,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } } - /** Assert that a job we started has succeeded and has the given result. */ + /** + * Assert that a job we started has succeeded and has the given result. + */ def expectJobResult(expected: Array[Int], waiter: JobWaiter = lastJobWaiter) { waiter.getResult match { case JobSucceeded(answer) => @@ -500,7 +525,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar )) } val recomputeOne = interceptStage(shuffleOneRdd) { - scheduler.resubmitFailedStages + scheduler.resubmitFailedStages() } val recomputeTwo = interceptStage(shuffleTwoRdd) { respondToTaskSet(recomputeOne, List( -- cgit v1.2.3 From 28e0cb9f312b7fb1b0236fd15ba0dd2f423e826d Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 2 Feb 2013 01:11:37 -0600 Subject: Fix createActorSystem not actually using the systemName parameter. This meant all system names were "spark", which worked, but didn't lead to the most intuitive log output. This fixes createActorSystem to use the passed system name, and refactors Master/Worker to encapsulate their system/actor names instead of having the clients guess at them. Note that the driver system name, "spark", is left as is, and is still repeated a few times, but that seems like a separate issue. --- .../scala/spark/deploy/LocalSparkCluster.scala | 38 +++++--------- .../main/scala/spark/deploy/client/Client.scala | 13 ++--- .../main/scala/spark/deploy/master/Master.scala | 24 +++++++-- .../main/scala/spark/deploy/worker/Worker.scala | 58 ++++++++++------------ .../scala/spark/storage/BlockManagerMaster.scala | 2 - core/src/main/scala/spark/util/AkkaUtils.scala | 6 ++- 6 files changed, 68 insertions(+), 73 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala index 2836574ecb..22319a96ca 100644 --- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala @@ -18,35 +18,23 @@ import scala.collection.mutable.ArrayBuffer private[spark] class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging { - val localIpAddress = Utils.localIpAddress + private val localIpAddress = Utils.localIpAddress + private val masterActorSystems = ArrayBuffer[ActorSystem]() + private val workerActorSystems = ArrayBuffer[ActorSystem]() - var masterActor : ActorRef = _ - var masterActorSystem : ActorSystem = _ - var masterPort : Int = _ - var masterUrl : String = _ - - val workerActorSystems = ArrayBuffer[ActorSystem]() - val workerActors = ArrayBuffer[ActorRef]() - - def start() : String = { + def start(): String = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") /* Start the Master */ - val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0) - masterActorSystem = actorSystem - masterUrl = "spark://" + localIpAddress + ":" + masterPort - masterActor = masterActorSystem.actorOf( - Props(new Master(localIpAddress, masterPort, 0)), name = "Master") + val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0) + masterActorSystems += masterSystem + val masterUrl = "spark://" + localIpAddress + ":" + masterPort - /* Start the Slaves */ + /* Start the Workers */ for (workerNum <- 1 to numWorkers) { - val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("sparkWorker" + workerNum, localIpAddress, 0) - workerActorSystems += actorSystem - val actor = actorSystem.actorOf( - Props(new Worker(localIpAddress, boundPort, 0, coresPerWorker, memoryPerWorker, masterUrl)), - name = "Worker") - workerActors += actor + val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker, + memoryPerWorker, masterUrl, null, Some(workerNum)) + workerActorSystems += workerSystem } return masterUrl @@ -57,7 +45,7 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I // Stop the workers before the master so they don't get upset that it disconnected workerActorSystems.foreach(_.shutdown()) workerActorSystems.foreach(_.awaitTermination()) - masterActorSystem.shutdown() - masterActorSystem.awaitTermination() + masterActorSystems.foreach(_.shutdown()) + masterActorSystems.foreach(_.awaitTermination()) } } diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala index 90fe9508cd..a63eee1233 100644 --- a/core/src/main/scala/spark/deploy/client/Client.scala +++ b/core/src/main/scala/spark/deploy/client/Client.scala @@ -9,6 +9,7 @@ import spark.{SparkException, Logging} import akka.remote.RemoteClientLifeCycleEvent import akka.remote.RemoteClientShutdown import spark.deploy.RegisterJob +import spark.deploy.master.Master import akka.remote.RemoteClientDisconnected import akka.actor.Terminated import akka.dispatch.Await @@ -24,26 +25,18 @@ private[spark] class Client( listener: ClientListener) extends Logging { - val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r - var actor: ActorRef = null var jobId: String = null - if (MASTER_REGEX.unapplySeq(masterUrl) == None) { - throw new SparkException("Invalid master URL: " + masterUrl) - } - class ClientActor extends Actor with Logging { var master: ActorRef = null var masterAddress: Address = null var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times override def preStart() { - val Seq(masterHost, masterPort) = MASTER_REGEX.unapplySeq(masterUrl).get - logInfo("Connecting to master spark://" + masterHost + ":" + masterPort) - val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort) + logInfo("Connecting to master " + masterUrl) try { - master = context.actorFor(akkaUrl) + master = context.actorFor(Master.toAkkaUrl(masterUrl)) masterAddress = master.path.address master ! RegisterJob(jobDescription) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index c618e87cdd..92e7914b1b 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -262,11 +262,29 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } private[spark] object Master { + private val systemName = "sparkMaster" + private val actorName = "Master" + private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r + def main(argStrings: Array[String]) { val args = new MasterArguments(argStrings) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port) - val actor = actorSystem.actorOf( - Props(new Master(args.ip, boundPort, args.webUiPort)), name = "Master") + val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort) actorSystem.awaitTermination() } + + /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */ + def toAkkaUrl(sparkUrl: String): String = { + sparkUrl match { + case sparkUrlRegex(host, port) => + "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName) + case _ => + throw new SparkException("Invalid master URL: " + sparkUrl) + } + } + + def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int) = { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) + val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName) + (actorSystem, boundPort) + } } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 8b41620d98..2219dd6262 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -1,7 +1,7 @@ package spark.deploy.worker import scala.collection.mutable.{ArrayBuffer, HashMap} -import akka.actor.{ActorRef, Props, Actor} +import akka.actor.{ActorRef, Props, Actor, ActorSystem} import spark.{Logging, Utils} import spark.util.AkkaUtils import spark.deploy._ @@ -13,6 +13,7 @@ import akka.remote.RemoteClientDisconnected import spark.deploy.RegisterWorker import spark.deploy.LaunchExecutor import spark.deploy.RegisterWorkerFailed +import spark.deploy.master.Master import akka.actor.Terminated import java.io.File @@ -27,7 +28,6 @@ private[spark] class Worker( extends Actor with Logging { val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs - val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r var master: ActorRef = null var masterWebUiUrl : String = "" @@ -48,11 +48,7 @@ private[spark] class Worker( def memoryFree: Int = memory - memoryUsed def createWorkDir() { - workDir = if (workDirPath != null) { - new File(workDirPath) - } else { - new File(sparkHome, "work") - } + workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work")) try { if (!workDir.exists() && !workDir.mkdirs()) { logError("Failed to create work directory " + workDir) @@ -68,8 +64,7 @@ private[spark] class Worker( override def preStart() { logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( ip, port, cores, Utils.memoryMegabytesToString(memory))) - val envVar = System.getenv("SPARK_HOME") - sparkHome = new File(if (envVar == null) "." else envVar) + sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse(".")) logInfo("Spark home: " + sparkHome) createWorkDir() connectToMaster() @@ -77,24 +72,15 @@ private[spark] class Worker( } def connectToMaster() { - masterUrl match { - case MASTER_REGEX(masterHost, masterPort) => { - logInfo("Connecting to master spark://" + masterHost + ":" + masterPort) - val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort) - try { - master = context.actorFor(akkaUrl) - master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress) - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(master) // Doesn't work with remote actors, but useful for testing - } catch { - case e: Exception => - logError("Failed to connect to master", e) - System.exit(1) - } - } - - case _ => - logError("Invalid master URL: " + masterUrl) + logInfo("Connecting to master " + masterUrl) + try { + master = context.actorFor(Master.toAkkaUrl(masterUrl)) + master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress) + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + context.watch(master) // Doesn't work with remote actors, but useful for testing + } catch { + case e: Exception => + logError("Failed to connect to master", e) System.exit(1) } } @@ -183,11 +169,19 @@ private[spark] class Worker( private[spark] object Worker { def main(argStrings: Array[String]) { val args = new WorkerArguments(argStrings) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port) - val actor = actorSystem.actorOf( - Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory, - args.master, args.workDir)), - name = "Worker") + val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores, + args.memory, args.master, args.workDir) actorSystem.awaitTermination() } + + def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int, + masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = { + // The LocalSparkCluster runs multiple local sparkWorkerX actor systems + val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") + val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) + val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory, + masterUrl, workDir)), name = "Worker") + (actorSystem, boundPort) + } + } diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 36398095a2..7be6b9fa87 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -27,8 +27,6 @@ private[spark] class BlockManagerMaster( val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager" - val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager" - val DEFAULT_MANAGER_IP: String = Utils.localHostName() val timeout = 10.seconds var driverActor: ActorRef = { diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index e0fdeffbc4..3a3626e8a0 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -18,9 +18,13 @@ import java.util.concurrent.TimeoutException * Various utility classes for working with Akka. */ private[spark] object AkkaUtils { + /** * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the * ActorSystem itself and its port (which is hard to get from Akka). + * + * Note: the `name` parameter is important, as even if a client sends a message to right + * host + port, if the system name is incorrect, Akka will drop the message. */ def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = { val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt @@ -41,7 +45,7 @@ private[spark] object AkkaUtils { akka.actor.default-dispatcher.throughput = %d """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize)) - val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) + val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader) // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a // hack because Akka doesn't let you figure out the port through the public API yet. -- cgit v1.2.3 From 696eec32c982ca516c506de33f383a173bcbd131 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 2 Feb 2013 02:03:26 -0600 Subject: Move executorMemory up into SchedulerBackend. --- .../scala/spark/scheduler/cluster/SchedulerBackend.scala | 12 ++++++++++++ .../scheduler/cluster/SparkDeploySchedulerBackend.scala | 9 --------- .../spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala | 10 ---------- .../scala/spark/scheduler/mesos/MesosSchedulerBackend.scala | 10 ---------- 4 files changed, 12 insertions(+), 29 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala index ddcd64d7c6..9ac875de3a 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala @@ -1,5 +1,7 @@ package spark.scheduler.cluster +import spark.Utils + /** * A backend interface for cluster scheduling systems that allows plugging in different ones under * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as @@ -11,5 +13,15 @@ private[spark] trait SchedulerBackend { def reviveOffers(): Unit def defaultParallelism(): Int + // Memory used by each executor (in megabytes) + protected val executorMemory = { + // TODO: Might need to add some extra memory for the non-heap parts of the JVM + Option(System.getProperty("spark.executor.memory")) + .orElse(Option(System.getenv("SPARK_MEM"))) + .map(Utils.memoryStringToMb) + .getOrElse(512) + } + + // TODO: Probably want to add a killTask too } diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 2f7099c5b9..59ff8bcb90 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -20,15 +20,6 @@ private[spark] class SparkDeploySchedulerBackend( val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt - // Memory used by each executor (in megabytes) - val executorMemory = { - // TODO: Might need to add some extra memory for the non-heap parts of the JVM - Option(System.getProperty("spark.executor.memory")) - .orElse(Option(System.getenv("SPARK_MEM"))) - .map(Utils.memoryStringToMb) - .getOrElse(512) - } - override def start() { super.start() diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index 7bf56a05d6..b481ec0a72 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -35,16 +35,6 @@ private[spark] class CoarseMesosSchedulerBackend( val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures - // Memory used by each executor (in megabytes) - val executorMemory = { - if (System.getenv("SPARK_MEM") != null) { - Utils.memoryStringToMb(System.getenv("SPARK_MEM")) - // TODO: Might need to add some extra memory for the non-heap parts of the JVM - } else { - 512 - } - } - // Lock used to wait for scheduler to be registered var isRegistered = false val registeredLock = new Object() diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index eab1c60e0b..5c8b531de3 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -29,16 +29,6 @@ private[spark] class MesosSchedulerBackend( with MScheduler with Logging { - // Memory used by each executor (in megabytes) - val EXECUTOR_MEMORY = { - if (System.getenv("SPARK_MEM") != null) { - Utils.memoryStringToMb(System.getenv("SPARK_MEM")) - // TODO: Might need to add some extra memory for the non-heap parts of the JVM - } else { - 512 - } - } - // Lock used to wait for scheduler to be registered var isRegistered = false val registeredLock = new Object() -- cgit v1.2.3 From cae8a6795c7f454b74c8d3c4425a6ced151d6d9b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 2 Feb 2013 02:15:39 -0600 Subject: Fix dangling old variable names. --- core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index 5c8b531de3..300766d0f5 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -79,7 +79,7 @@ private[spark] class MesosSchedulerBackend( val memory = Resource.newBuilder() .setName("mem") .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(EXECUTOR_MEMORY).build()) + .setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build()) .build() val command = CommandInfo.newBuilder() .setValue(execScript) @@ -151,7 +151,7 @@ private[spark] class MesosSchedulerBackend( def enoughMemory(o: Offer) = { val mem = getResource(o.getResourcesList, "mem") val slaveId = o.getSlaveId.getValue - mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId) + mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId) } for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) { -- cgit v1.2.3 From 7aba123f0c0fd024105462b3a0b203cd357c67e9 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 2 Feb 2013 13:53:28 -0600 Subject: Further simplify checking for Nil. --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index b62b25f688..2a646dd0f5 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -176,9 +176,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def visit(rdd: RDD[_]) { if (!visited(rdd)) { visited += rdd - val locs = getCacheLocs(rdd) - val atLeastOneMissing = (0 until rdd.splits.size).exists(locs(_) == Nil) - if (atLeastOneMissing) { + if (getCacheLocs(rdd).contains(Nil)) { for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => -- cgit v1.2.3 From 34a7bcdb3a19deed18b25225daf47ff22ee20869 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 2 Feb 2013 19:40:30 -0800 Subject: Formatting --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 8cfc08e5ac..2a35915560 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -205,8 +205,9 @@ class DAGScheduler( missing.toList } - /** Returns (and does not) submit a JobSubmitted event suitable to run a given job, and - * a JobWaiter whose getResult() method will return the result of the job when it is complete. + /** + * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a + * JobWaiter whose getResult() method will return the result of the job when it is complete. * * The job is assumed to have at least one partition; zero partition jobs should be handled * without a JobSubmitted event. @@ -308,7 +309,8 @@ class DAGScheduler( return false } - /** Resubmit any failed stages. Ordinarily called after a small amount of time has passed since + /** + * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since * the last fetch failure. */ private[scheduler] def resubmitFailedStages() { @@ -321,7 +323,8 @@ class DAGScheduler( } } - /** Check for waiting or failed stages which are now eligible for resubmission. + /** + * Check for waiting or failed stages which are now eligible for resubmission. * Ordinarily run on every iteration of the event loop. */ private[scheduler] def submitWaitingStages() { @@ -366,9 +369,9 @@ class DAGScheduler( // the same time, so we want to make sure we've identified all the reduce tasks that depend // on the failed node. if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { - resubmitFailedStages + resubmitFailedStages() } else { - submitWaitingStages + submitWaitingStages() } } } -- cgit v1.2.3 From 8fbd5380b7f36842297f624bad3a2513f7eca47b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 Feb 2013 06:44:49 +0000 Subject: Fetch fewer objects in PySpark's take() method. --- core/src/main/scala/spark/api/python/PythonRDD.scala | 11 +++++++++-- python/pyspark/rdd.py | 4 ++++ 2 files changed, 13 insertions(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 39758e94f4..ab8351e55e 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -238,6 +238,11 @@ private[spark] object PythonRDD { } def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { + import scala.collection.JavaConverters._ + writeIteratorToPickleFile(items.asScala, filename) + } + + def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) { val file = new DataOutputStream(new FileOutputStream(filename)) for (item <- items) { writeAsPickle(item, file) @@ -245,8 +250,10 @@ private[spark] object PythonRDD { file.close() } - def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] = - rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head + def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = { + implicit val cm : ClassManifest[T] = rdd.elementClassManifest + rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator + } } private object Pickle { diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index fb144bc45d..4cda6cf661 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -372,6 +372,10 @@ class RDD(object): items = [] for partition in range(self._jrdd.splits().size()): iterator = self.ctx._takePartition(self._jrdd.rdd(), partition) + # Each item in the iterator is a string, Python object, batch of + # Python objects. Regardless, it is sufficient to take `num` + # of these objects in order to collect `num` Python objects: + iterator = iterator.take(num) items.extend(self._collect_iterator_through_file(iterator)) if len(items) >= num: break -- cgit v1.2.3 From 9163c3705d98ca19c09fe5618e347b9d20f88f63 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 2 Feb 2013 23:34:47 -0800 Subject: Formatting --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 2a35915560..edbfd1c45f 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -265,7 +265,8 @@ class DAGScheduler( return listener.awaitResult() // Will throw an exception if the job fails } - /** Process one event retrieved from the event queue. + /** + * Process one event retrieved from the event queue. * Returns true if we should stop the event loop. */ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { -- cgit v1.2.3 From aa4ee1e9e5485c1b96474e704c76225a2b8a7da9 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 4 Feb 2013 11:06:31 -0800 Subject: Fix failing test --- core/src/test/scala/spark/MapOutputTrackerSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'core') diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index f4e7ec39fe..dd19442dcb 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -79,8 +79,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("remote fetch") { try { System.clearProperty("spark.driver.host") // In case some previous test had set it - val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("test", "localhost", 0) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0) System.setProperty("spark.driver.port", boundPort.toString) val masterTracker = new MapOutputTracker(actorSystem, true) val slaveTracker = new MapOutputTracker(actorSystem, false) -- cgit v1.2.3 From f6ec547ea7b56ee607a4c2a69206f8952318eaf1 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 4 Feb 2013 13:14:54 -0800 Subject: Small fix to test for distinct --- core/src/test/scala/spark/RDDSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 89a3687386..fe7deb10d6 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -14,7 +14,7 @@ class RDDSuite extends FunSuite with LocalSparkContext { val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) assert(dups.distinct().count() === 4) assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses? - assert(dups.distinct().collect === dups.distinct().collect) + assert(dups.distinct.collect === dups.distinct().collect) assert(dups.distinct(2).collect === dups.distinct().collect) assert(nums.reduce(_ + _) === 10) assert(nums.fold(0)(_ + _) === 10) -- cgit v1.2.3 From 8bd0e888f377f13ac239df4ffd49fc666095e764 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 5 Feb 2013 17:50:25 -0600 Subject: Inline mergePair to look more like the narrow dep branch. No functionality changes, I think this is just more consistent given mergePair isn't called multiple times/recursive. Also added a comment to explain the usual case of having two parent RDDs. --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 8fafd27bb6..4893fe8d78 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -84,6 +84,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { val split = s.asInstanceOf[CoGroupSplit] val numRdds = split.deps.size + // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs) val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] def getSeq(k: K): Seq[ArrayBuffer[Any]] = { val seq = map.get(k) @@ -104,13 +105,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) } case ShuffleCoGroupSplitDep(shuffleId) => { // Read map outputs of shuffle - def mergePair(pair: (K, Seq[Any])) { - val mySeq = getSeq(pair._1) - for (v <- pair._2) - mySeq(depNum) += v - } val fetcher = SparkEnv.get.shuffleFetcher - fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair) + for ((k, vs) <- fetcher.fetch[K, Seq[Any]](shuffleId, split.index)) { + getSeq(k)(depNum) ++= vs + } } } JavaConversions.mapAsScalaMap(map).iterator -- cgit v1.2.3 From 1ba3393ceb5709620a28b8bc01826153993fc444 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 5 Feb 2013 17:56:50 -0600 Subject: Increase DriverSuite timeout. --- core/src/test/scala/spark/DriverSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala index 342610e1dd..5e84b3a66a 100644 --- a/core/src/test/scala/spark/DriverSuite.scala +++ b/core/src/test/scala/spark/DriverSuite.scala @@ -9,10 +9,11 @@ import org.scalatest.time.SpanSugar._ class DriverSuite extends FunSuite with Timeouts { test("driver should exit after finishing") { + assert(System.getenv("SPARK_HOME") != null) // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) forAll(masters) { (master: String) => - failAfter(10 seconds) { + failAfter(30 seconds) { Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master), new File(System.getenv("SPARK_HOME"))) } -- cgit v1.2.3 From 0e19093fd89ec9740f98cdcffd1ec09f4faf2490 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 5 Feb 2013 18:58:00 -0600 Subject: Handle Terminated to avoid endless DeathPactExceptions. Credit to Roland Kuhn, Akka's tech lead, for pointing out this various obvious fix, but StandaloneExecutorBackend.preStart's catch block would never (ever) get hit, because all of the operation's in preStart are async. So, the System.exit in the catch block was skipped, and instead Akka was sending Terminated messages which, since we didn't handle, it turned into DeathPactException, which started a postRestart/preStart infinite loop. --- .../main/scala/spark/deploy/worker/Worker.scala | 7 ++---- .../spark/executor/StandaloneExecutorBackend.scala | 25 ++++++++++------------ 2 files changed, 13 insertions(+), 19 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 8b41620d98..48177a638a 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -1,19 +1,16 @@ package spark.deploy.worker import scala.collection.mutable.{ArrayBuffer, HashMap} -import akka.actor.{ActorRef, Props, Actor} +import akka.actor.{ActorRef, Props, Actor, Terminated} import spark.{Logging, Utils} import spark.util.AkkaUtils import spark.deploy._ -import akka.remote.RemoteClientLifeCycleEvent +import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} import java.text.SimpleDateFormat import java.util.Date -import akka.remote.RemoteClientShutdown -import akka.remote.RemoteClientDisconnected import spark.deploy.RegisterWorker import spark.deploy.LaunchExecutor import spark.deploy.RegisterWorkerFailed -import akka.actor.Terminated import java.io.File private[spark] class Worker( diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index e45288ff53..224c126fdd 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -4,16 +4,15 @@ import java.nio.ByteBuffer import spark.Logging import spark.TaskState.TaskState import spark.util.AkkaUtils -import akka.actor.{ActorRef, Actor, Props} +import akka.actor.{ActorRef, Actor, Props, Terminated} +import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} import java.util.concurrent.{TimeUnit, ThreadPoolExecutor, SynchronousQueue} -import akka.remote.RemoteClientLifeCycleEvent import spark.scheduler.cluster._ import spark.scheduler.cluster.RegisteredExecutor import spark.scheduler.cluster.LaunchTask import spark.scheduler.cluster.RegisterExecutorFailed import spark.scheduler.cluster.RegisterExecutor - private[spark] class StandaloneExecutorBackend( executor: Executor, driverUrl: String, @@ -27,17 +26,11 @@ private[spark] class StandaloneExecutorBackend( var driver: ActorRef = null override def preStart() { - try { - logInfo("Connecting to driver: " + driverUrl) - driver = context.actorFor(driverUrl) - driver ! RegisterExecutor(executorId, hostname, cores) - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(driver) // Doesn't work with remote actors, but useful for testing - } catch { - case e: Exception => - logError("Failed to connect to driver", e) - System.exit(1) - } + logInfo("Connecting to driver: " + driverUrl) + driver = context.actorFor(driverUrl) + driver ! RegisterExecutor(executorId, hostname, cores) + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + context.watch(driver) // Doesn't work with remote actors, but useful for testing } override def receive = { @@ -52,6 +45,10 @@ private[spark] class StandaloneExecutorBackend( case LaunchTask(taskDesc) => logInfo("Got assigned task " + taskDesc.taskId) executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) + + case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => + logError("Driver terminated or disconnected! Shutting down.") + System.exit(1) } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { -- cgit v1.2.3